Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
987589ea
Commit
987589ea
authored
Feb 21, 2026
by
yangjianbo
Browse files
Merge branch 'test' into release
parents
372e04f6
03f69dd3
Changes
109
Expand all
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/proxy_handler.go
View file @
987589ea
...
...
@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response
.
Success
(
c
,
result
)
}
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func
(
h
*
ProxyHandler
)
CheckQuality
(
c
*
gin
.
Context
)
{
proxyID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid proxy ID"
)
return
}
result
,
err
:=
h
.
adminService
.
CheckProxyQuality
(
c
.
Request
.
Context
(),
proxyID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func
(
h
*
ProxyHandler
)
GetStats
(
c
*
gin
.
Context
)
{
...
...
backend/internal/handler/dto/mappers.go
View file @
987589ea
...
...
@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled
:=
true
out
.
EnableSessionIDMasking
=
&
enabled
}
// 缓存 TTL 强制替换
if
a
.
IsCacheTTLOverrideEnabled
()
{
enabled
:=
true
out
.
CacheTTLOverrideEnabled
=
&
enabled
target
:=
a
.
GetCacheTTLOverrideTarget
()
out
.
CacheTTLOverrideTarget
=
&
target
}
}
return
out
...
...
@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode
:
p
.
CountryCode
,
Region
:
p
.
Region
,
City
:
p
.
City
,
QualityStatus
:
p
.
QualityStatus
,
QualityScore
:
p
.
QualityScore
,
QualityGrade
:
p
.
QualityGrade
,
QualitySummary
:
p
.
QualitySummary
,
QualityChecked
:
p
.
QualityChecked
,
}
}
...
...
@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ImageSize
:
l
.
ImageSize
,
MediaType
:
l
.
MediaType
,
UserAgent
:
l
.
UserAgent
,
CacheTTLOverridden
:
l
.
CacheTTLOverridden
,
CreatedAt
:
l
.
CreatedAt
,
User
:
UserFromServiceShallow
(
l
.
User
),
APIKey
:
APIKeyFromService
(
l
.
APIKey
),
...
...
backend/internal/handler/dto/types.go
View file @
987589ea
...
...
@@ -156,6 +156,11 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking
*
bool
`json:"session_id_masking_enabled,omitempty"`
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
CacheTTLOverrideEnabled
*
bool
`json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget
*
string
`json:"cache_ttl_override_target,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
CountryCode
string
`json:"country_code,omitempty"`
Region
string
`json:"region,omitempty"`
City
string
`json:"city,omitempty"`
QualityStatus
string
`json:"quality_status,omitempty"`
QualityScore
*
int
`json:"quality_score,omitempty"`
QualityGrade
string
`json:"quality_grade,omitempty"`
QualitySummary
string
`json:"quality_summary,omitempty"`
QualityChecked
*
int64
`json:"quality_checked,omitempty"`
}
type
ProxyAccountSummary
struct
{
...
...
@@ -280,6 +290,9 @@ type UsageLog struct {
// User-Agent
UserAgent
*
string
`json:"user_agent"`
// Cache TTL Override 标记
CacheTTLOverridden
bool
`json:"cache_ttl_overridden"`
CreatedAt
time
.
Time
`json:"created_at"`
User
*
User
`json:"user,omitempty"`
...
...
backend/internal/handler/sora_gateway_handler.go
View file @
987589ea
...
...
@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
...
...
@@ -20,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
...
...
@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
streamMode
string
soraTLSEnabled
bool
soraMediaSigningKey
string
soraMediaRoot
string
}
...
...
@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
3
streamMode
:=
"force"
soraTLSEnabled
:=
true
signKey
:=
""
mediaRoot
:=
"/app/data/sora"
if
cfg
!=
nil
{
...
...
@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
if
mode
:=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraStreamMode
);
mode
!=
""
{
streamMode
=
mode
}
soraTLSEnabled
=
!
cfg
.
Sora
.
Client
.
DisableTLSFingerprint
signKey
=
strings
.
TrimSpace
(
cfg
.
Gateway
.
SoraMediaSigningKey
)
if
root
:=
strings
.
TrimSpace
(
cfg
.
Sora
.
Storage
.
LocalPath
);
root
!=
""
{
mediaRoot
=
root
...
...
@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
streamMode
:
strings
.
ToLower
(
streamMode
),
soraTLSEnabled
:
soraTLSEnabled
,
soraMediaSigningKey
:
signKey
,
soraMediaRoot
:
mediaRoot
,
}
...
...
@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
var
lastFailoverBody
[]
byte
var
lastFailoverHeaders
http
.
Header
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
,
""
)
...
...
@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int
(
"last_upstream_status"
,
lastFailoverStatus
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"last_upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.failover_exhausted_no_available_accounts"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
proxyBound
:=
account
.
ProxyID
!=
nil
proxyID
:=
int64
(
0
)
if
account
.
ProxyID
!=
nil
{
proxyID
=
*
account
.
ProxyID
}
tlsFingerprintEnabled
:=
h
.
soraTLSEnabled
accountReleaseFunc
:=
selection
.
ReleaseFunc
if
!
selection
.
Acquired
{
...
...
@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_wait_counter_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
reqLog
.
Warn
(
"sora.account_wait_counter_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
}
else
if
!
canWait
{
reqLog
.
Info
(
"sora.account_wait_queue_full"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"max_waiting"
,
selection
.
WaitPlan
.
MaxWaiting
),
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
...
...
@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
reqLog
.
Warn
(
"sora.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
...
...
@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverStatus
=
failoverErr
.
StatusCode
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_exhausted"
,
fields
...
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
lastFailoverHeaders
,
lastFailoverBody
,
streamStarted
)
return
}
lastFailoverStatus
=
failoverErr
.
StatusCode
lastFailoverHeaders
=
cloneHTTPHeaders
(
failoverErr
.
ResponseHeaders
)
lastFailoverBody
=
failoverErr
.
ResponseBody
switchCount
++
reqLog
.
Warn
(
"sora.upstream_failover_switching"
,
upstreamErrCode
,
upstreamErrMsg
:=
extractUpstreamErrorCodeAndMessage
(
lastFailoverBody
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
lastFailoverHeaders
,
lastFailoverBody
)
fields
:=
[]
zap
.
Field
{
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
String
(
"upstream_error_code"
,
upstreamErrCode
),
zap
.
String
(
"upstream_error_message"
,
upstreamErrMsg
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
)
}
if
rayID
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_ray"
,
rayID
))
}
if
mitigated
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_cf_mitigated"
,
mitigated
))
}
if
contentType
!=
""
{
fields
=
append
(
fields
,
zap
.
String
(
"upstream_content_type"
,
contentType
))
}
reqLog
.
Warn
(
"sora.upstream_failover_switching"
,
fields
...
)
continue
}
reqLog
.
Error
(
"sora.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
reqLog
.
Error
(
"sora.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Error
(
err
),
)
return
}
...
...
@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}(
result
,
account
,
userAgent
,
clientIP
)
reqLog
.
Debug
(
"sora.request_completed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int64
(
"proxy_id"
,
proxyID
),
zap
.
Bool
(
"proxy_bound"
,
proxyBound
),
zap
.
Bool
(
"tls_fingerprint_enabled"
,
tlsFingerprintEnabled
),
zap
.
Int
(
"switch_count"
,
switchCount
),
)
return
...
...
@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
statusCode
int
,
streamStarted
bool
)
{
status
,
errType
,
errMsg
:=
h
.
mapUpstreamError
(
statusCode
)
func
(
h
*
SoraGatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
,
streamStarted
bool
)
{
status
,
errType
,
errMsg
:=
h
.
mapUpstreamError
(
statusCode
,
responseHeaders
,
responseBody
)
h
.
handleStreamingAwareError
(
c
,
status
,
errType
,
errMsg
,
streamStarted
)
}
func
(
h
*
SoraGatewayHandler
)
mapUpstreamError
(
statusCode
int
)
(
int
,
string
,
string
)
{
func
(
h
*
SoraGatewayHandler
)
mapUpstreamError
(
statusCode
int
,
responseHeaders
http
.
Header
,
responseBody
[]
byte
)
(
int
,
string
,
string
)
{
if
isSoraCloudflareChallengeResponse
(
statusCode
,
responseHeaders
,
responseBody
)
{
baseMsg
:=
fmt
.
Sprintf
(
"Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry."
,
statusCode
)
return
http
.
StatusBadGateway
,
"upstream_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
upstreamCode
,
upstreamMessage
:=
extractUpstreamErrorCodeAndMessage
(
responseBody
)
if
strings
.
EqualFold
(
upstreamCode
,
"cf_shield_429"
)
{
baseMsg
:=
"Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
formatSoraCloudflareChallengeMessage
(
baseMsg
,
responseHeaders
,
responseBody
)
}
if
shouldPassthroughSoraUpstreamMessage
(
statusCode
,
upstreamMessage
)
{
switch
statusCode
{
case
401
,
403
,
404
,
500
,
502
,
503
,
504
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
upstreamMessage
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
upstreamMessage
}
}
switch
statusCode
{
case
401
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream authentication failed, please contact administrator"
case
403
:
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream access forbidden, please contact administrator"
case
404
:
if
strings
.
EqualFold
(
upstreamCode
,
"unsupported_country_code"
)
{
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream region capability unavailable for this account, please contact administrator"
}
return
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream capability unavailable for this account, please contact administrator"
case
429
:
return
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Upstream rate limit exceeded, please retry later"
case
529
:
...
...
@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
}
}
func
cloneHTTPHeaders
(
headers
http
.
Header
)
http
.
Header
{
if
headers
==
nil
{
return
nil
}
return
headers
.
Clone
()
}
func
extractSoraFailoverHeaderInsights
(
headers
http
.
Header
,
body
[]
byte
)
(
rayID
,
mitigated
,
contentType
string
)
{
if
headers
!=
nil
{
mitigated
=
strings
.
TrimSpace
(
headers
.
Get
(
"cf-mitigated"
))
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"content-type"
))
if
contentType
==
""
{
contentType
=
strings
.
TrimSpace
(
headers
.
Get
(
"Content-Type"
))
}
}
rayID
=
soraerror
.
ExtractCloudflareRayID
(
headers
,
body
)
return
rayID
,
mitigated
,
contentType
}
func
isSoraCloudflareChallengeResponse
(
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
bool
{
return
soraerror
.
IsCloudflareChallengeResponse
(
statusCode
,
headers
,
body
)
}
func
shouldPassthroughSoraUpstreamMessage
(
statusCode
int
,
message
string
)
bool
{
message
=
strings
.
TrimSpace
(
message
)
if
message
==
""
{
return
false
}
if
statusCode
==
http
.
StatusForbidden
||
statusCode
==
http
.
StatusTooManyRequests
{
lower
:=
strings
.
ToLower
(
message
)
if
strings
.
Contains
(
lower
,
"<html"
)
||
strings
.
Contains
(
lower
,
"<!doctype html"
)
||
strings
.
Contains
(
lower
,
"window._cf_chl_opt"
)
{
return
false
}
}
return
true
}
func
formatSoraCloudflareChallengeMessage
(
base
string
,
headers
http
.
Header
,
body
[]
byte
)
string
{
return
soraerror
.
FormatCloudflareChallengeMessage
(
base
,
headers
,
body
)
}
func
extractUpstreamErrorCodeAndMessage
(
body
[]
byte
)
(
string
,
string
)
{
return
soraerror
.
ExtractUpstreamErrorCodeAndMessage
(
body
)
}
func
(
h
*
SoraGatewayHandler
)
handleStreamingAwareError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
,
streamStarted
bool
)
{
if
streamStarted
{
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
ok
{
errorEvent
:=
fmt
.
Sprintf
(
`event: error`
+
"
\n
"
+
`data: {"error": {"type": "%s", "message": "%s"}}`
+
"
\n\n
"
,
errType
,
message
)
errorData
:=
map
[
string
]
any
{
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"event: error
\n
data: %s
\n\n
"
,
string
(
jsonBytes
))
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
);
err
!=
nil
{
_
=
c
.
Error
(
err
)
}
...
...
backend/internal/handler/sora_gateway_handler_test.go
View file @
987589ea
...
...
@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func
(
s
*
stubSoraClient
)
CreateVideoTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraVideoRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
CreateStoryboardTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraStoryboardRequest
)
(
string
,
error
)
{
return
"task-video"
,
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterVideo
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"cameo-1"
,
nil
}
func
(
s
*
stubSoraClient
)
GetCameoStatus
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
(
*
service
.
SoraCameoStatus
,
error
)
{
return
&
service
.
SoraCameoStatus
{
Status
:
"finalized"
,
StatusMessage
:
"Completed"
,
DisplayNameHint
:
"Character"
,
UsernameHint
:
"user.character"
,
ProfileAssetURL
:
"https://example.com/avatar.webp"
,
},
nil
}
func
(
s
*
stubSoraClient
)
DownloadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
imageURL
string
)
([]
byte
,
error
)
{
return
[]
byte
(
"avatar"
),
nil
}
func
(
s
*
stubSoraClient
)
UploadCharacterImage
(
ctx
context
.
Context
,
account
*
service
.
Account
,
data
[]
byte
)
(
string
,
error
)
{
return
"asset-pointer"
,
nil
}
func
(
s
*
stubSoraClient
)
FinalizeCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
req
service
.
SoraCharacterFinalizeRequest
)
(
string
,
error
)
{
return
"character-1"
,
nil
}
func
(
s
*
stubSoraClient
)
SetCharacterPublic
(
ctx
context
.
Context
,
account
*
service
.
Account
,
cameoID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
DeleteCharacter
(
ctx
context
.
Context
,
account
*
service
.
Account
,
characterID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
PostVideoForWatermarkFree
(
ctx
context
.
Context
,
account
*
service
.
Account
,
generationID
string
)
(
string
,
error
)
{
return
"s_post"
,
nil
}
func
(
s
*
stubSoraClient
)
DeletePost
(
ctx
context
.
Context
,
account
*
service
.
Account
,
postID
string
)
error
{
return
nil
}
func
(
s
*
stubSoraClient
)
GetWatermarkFreeURLCustom
(
ctx
context
.
Context
,
account
*
service
.
Account
,
parseURL
,
parseToken
,
postID
string
)
(
string
,
error
)
{
return
"https://example.com/no-watermark.mp4"
,
nil
}
func
(
s
*
stubSoraClient
)
EnhancePrompt
(
ctx
context
.
Context
,
account
*
service
.
Account
,
prompt
,
expansionLevel
string
,
durationS
int
)
(
string
,
error
)
{
return
"enhanced prompt"
,
nil
}
func
(
s
*
stubSoraClient
)
GetImageTask
(
ctx
context
.
Context
,
account
*
service
.
Account
,
taskID
string
)
(
*
service
.
SoraImageTaskStatus
,
error
)
{
return
&
service
.
SoraImageTaskStatus
{
ID
:
taskID
,
Status
:
"completed"
,
URLs
:
s
.
imageURLs
},
nil
}
...
...
@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func
(
r
*
stubAccountRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
r
*
stubAccountRepo
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
...
...
@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
require
.
NotEmpty
(
t
,
hash3
)
require
.
NotEqual
(
t
,
hash
,
hash3
)
// 不同来源应产生不同 hash
}
func
TestSoraHandleStreamingAwareError_JSONEscaping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
errType
string
message
string
}{
{
name
:
"包含双引号"
,
errType
:
"upstream_error"
,
message
:
`upstream returned "invalid" payload`
,
},
{
name
:
"包含换行和制表符"
,
errType
:
"rate_limit_error"
,
message
:
"line1
\n
line2
\t
tab"
,
},
{
name
:
"包含反斜杠"
,
errType
:
"upstream_error"
,
message
:
`path C:\Users\test\file.txt not found`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
tt
.
errType
,
tt
.
message
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
),
"应以 SSE error 事件开头"
)
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
),
"应以 SSE 结束分隔符结尾"
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
,
"SSE 错误事件应包含 event 行和 data 行"
)
require
.
Equal
(
t
,
"event: error"
,
lines
[
0
])
require
.
True
(
t
,
strings
.
HasPrefix
(
lines
[
1
],
"data: "
),
"第二行应为 data 前缀"
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
),
"data 行必须是合法 JSON"
)
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
,
"JSON 中应包含 error 对象"
)
require
.
Equal
(
t
,
tt
.
errType
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
tt
.
message
,
errorObj
[
"message"
])
})
}
}
func
TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
SoraGatewayHandler
{}
resp
:=
[]
byte
(
`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`
)
h
.
handleFailoverExhausted
(
c
,
http
.
StatusBadGateway
,
nil
,
resp
,
true
)
body
:=
w
.
Body
.
String
()
require
.
True
(
t
,
strings
.
HasPrefix
(
body
,
"event: error
\n
"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
body
,
"
\n\n
"
))
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
body
,
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
require
.
Equal
(
t
,
"invalid
\"
prompt
\"\n
line2"
,
errorObj
[
"message"
])
}
func
TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d01b0e9ecc35829-SEA"
)
body
:=
[]
byte
(
`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusForbidden
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare challenge"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d01b0e9ecc35829-SEA"
)
}
func
TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-ray"
,
"9d03b68c086027a1-SEA"
)
body
:=
[]
byte
(
`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`
)
h
:=
&
SoraGatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
http
.
StatusTooManyRequests
,
headers
,
body
,
true
)
lines
:=
strings
.
Split
(
strings
.
TrimSuffix
(
w
.
Body
.
String
(),
"
\n\n
"
),
"
\n
"
)
require
.
Len
(
t
,
lines
,
2
)
jsonStr
:=
strings
.
TrimPrefix
(
lines
[
1
],
"data: "
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
parsed
))
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"rate_limit_error"
,
errorObj
[
"type"
])
msg
,
_
:=
errorObj
[
"message"
]
.
(
string
)
require
.
Contains
(
t
,
msg
,
"Cloudflare shield"
)
require
.
Contains
(
t
,
msg
,
"cf-ray: 9d03b68c086027a1-SEA"
)
}
func
TestExtractSoraFailoverHeaderInsights
(
t
*
testing
.
T
)
{
headers
:=
http
.
Header
{}
headers
.
Set
(
"cf-mitigated"
,
"challenge"
)
headers
.
Set
(
"content-type"
,
"text/html"
)
body
:=
[]
byte
(
`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`
)
rayID
,
mitigated
,
contentType
:=
extractSoraFailoverHeaderInsights
(
headers
,
body
)
require
.
Equal
(
t
,
"9cff2d62d83bb98d"
,
rayID
)
require
.
Equal
(
t
,
"challenge"
,
mitigated
)
require
.
Equal
(
t
,
"text/html"
,
contentType
)
}
backend/internal/pkg/claude/constants.go
View file @
987589ea
...
...
@@ -10,6 +10,7 @@ const (
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting
=
"token-counting-2024-11-01"
BetaContext1M
=
"context-1m-2025-08-07"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
...
...
@@ -77,6 +78,12 @@ var DefaultModels = []Model{
DisplayName
:
"Claude Opus 4.6"
,
CreatedAt
:
"2026-02-06T00:00:00Z"
,
},
{
ID
:
"claude-sonnet-4-6"
,
Type
:
"model"
,
DisplayName
:
"Claude Sonnet 4.6"
,
CreatedAt
:
"2026-02-18T00:00:00Z"
,
},
{
ID
:
"claude-sonnet-4-5-20250929"
,
Type
:
"model"
,
...
...
backend/internal/pkg/openai/oauth.go
View file @
987589ea
...
...
@@ -17,6 +17,8 @@ import (
const
(
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID
=
"app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID
=
"app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL
=
"https://auth.openai.com/oauth/authorize"
...
...
backend/internal/repository/account_repo.go
View file @
987589ea
...
...
@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
func
(
r
*
accountRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
)
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
,
0
)
}
func
(
r
*
accountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
accountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
Account
.
Query
()
if
platform
!=
""
{
...
...
@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
if
search
!=
""
{
q
=
q
.
Where
(
dbaccount
.
NameContainsFold
(
search
))
}
if
groupID
>
0
{
q
=
q
.
Where
(
dbaccount
.
HasAccountGroupsWith
(
dbaccountgroup
.
GroupIDEQ
(
groupID
)))
}
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
987589ea
...
...
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt
.
setup
(
client
)
accounts
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
tt
.
platform
,
tt
.
accType
,
tt
.
status
,
tt
.
search
)
accounts
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
tt
.
platform
,
tt
.
accType
,
tt
.
status
,
tt
.
search
,
0
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
tt
.
wantCount
)
if
tt
.
validate
!=
nil
{
...
...
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s
.
Require
()
.
Len
(
got
.
Groups
,
1
,
"expected Groups to be populated"
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Groups
[
0
]
.
ID
)
accounts
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
""
,
"acc"
)
accounts
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
""
,
"acc"
,
0
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
accounts
,
1
)
...
...
backend/internal/repository/openai_oauth_service.go
View file @
987589ea
...
...
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...
...
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func
(
s
*
openaiOAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
s
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
""
)
}
func
(
s
*
openaiOAuthService
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
if
strings
.
TrimSpace
(
clientID
)
!=
""
{
return
s
.
refreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
strings
.
TrimSpace
(
clientID
))
}
clientIDs
:=
[]
string
{
openai
.
ClientID
,
openai
.
SoraClientID
,
}
seen
:=
make
(
map
[
string
]
struct
{},
len
(
clientIDs
))
var
lastErr
error
for
_
,
clientID
:=
range
clientIDs
{
clientID
=
strings
.
TrimSpace
(
clientID
)
if
clientID
==
""
{
continue
}
if
_
,
ok
:=
seen
[
clientID
];
ok
{
continue
}
seen
[
clientID
]
=
struct
{}{}
tokenResp
,
err
:=
s
.
refreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
clientID
)
if
err
==
nil
{
return
tokenResp
,
nil
}
lastErr
=
err
}
if
lastErr
!=
nil
{
return
nil
,
lastErr
}
return
nil
,
infraerrors
.
New
(
http
.
StatusBadGateway
,
"OPENAI_OAUTH_TOKEN_REFRESH_FAILED"
,
"token refresh failed"
)
}
func
(
s
*
openaiOAuthService
)
refreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
client
:=
createOpenAIReqClient
(
proxyURL
)
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"refresh_token"
)
formData
.
Set
(
"refresh_token"
,
refreshToken
)
formData
.
Set
(
"client_id"
,
openai
.
C
lientID
)
formData
.
Set
(
"client_id"
,
c
lientID
)
formData
.
Set
(
"scope"
,
openai
.
RefreshScopes
)
var
tokenResp
openai
.
TokenResponse
...
...
backend/internal/repository/openai_oauth_service_test.go
View file @
987589ea
...
...
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require
.
Equal
(
s
.
T
(),
"rt2"
,
resp
.
RefreshToken
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_FallbackToSoraClientID
()
{
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
if
clientID
==
openai
.
ClientID
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
_
,
_
=
io
.
WriteString
(
w
,
"invalid_grant"
)
return
}
if
clientID
==
openai
.
SoraClientID
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`
)
return
}
w
.
WriteHeader
(
http
.
StatusBadRequest
)
}))
resp
,
err
:=
s
.
svc
.
RefreshToken
(
s
.
ctx
,
"rt"
,
""
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshToken"
)
require
.
Equal
(
s
.
T
(),
"at-sora"
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
"rt-sora"
,
resp
.
RefreshToken
)
require
.
Equal
(
s
.
T
(),
[]
string
{
openai
.
ClientID
,
openai
.
SoraClientID
},
seenClientIDs
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestRefreshToken_UseProvidedClientID
()
{
const
customClientID
=
"custom-client-id"
var
seenClientIDs
[]
string
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
err
:=
r
.
ParseForm
();
err
!=
nil
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
clientID
:=
r
.
PostForm
.
Get
(
"client_id"
)
seenClientIDs
=
append
(
seenClientIDs
,
clientID
)
if
clientID
!=
customClientID
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
return
}
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
io
.
WriteString
(
w
,
`{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`
)
}))
resp
,
err
:=
s
.
svc
.
RefreshTokenWithClientID
(
s
.
ctx
,
"rt"
,
""
,
customClientID
)
require
.
NoError
(
s
.
T
(),
err
,
"RefreshTokenWithClientID"
)
require
.
Equal
(
s
.
T
(),
"at-custom"
,
resp
.
AccessToken
)
require
.
Equal
(
s
.
T
(),
"rt-custom"
,
resp
.
RefreshToken
)
require
.
Equal
(
s
.
T
(),
[]
string
{
customClientID
},
seenClientIDs
)
}
func
(
s
*
OpenAIOAuthServiceSuite
)
TestNonSuccessStatus_IncludesBody
()
{
s
.
setupServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
...
...
backend/internal/repository/usage_log_repo.go
View file @
987589ea
...
...
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort,
cache_ttl_overridden,
created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var
dateFormatWhitelist
=
map
[
string
]
string
{
...
...
@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_size,
media_type,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
...
...
@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
, $33
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
imageSize
,
mediaType
,
reasoningEffort
,
log
.
CacheTTLOverridden
,
createdAt
,
}
if
err
:=
scanSingleRow
(
ctx
,
sqlq
,
query
,
args
,
&
log
.
ID
,
&
log
.
CreatedAt
);
err
!=
nil
{
...
...
@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageSize
sql
.
NullString
mediaType
sql
.
NullString
reasoningEffort
sql
.
NullString
cacheTTLOverridden
bool
createdAt
time
.
Time
)
...
...
@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
imageSize
,
&
mediaType
,
&
reasoningEffort
,
&
cacheTTLOverridden
,
&
createdAt
,
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
BillingType
:
int8
(
billingType
),
Stream
:
stream
,
ImageCount
:
imageCount
,
CacheTTLOverridden
:
cacheTTLOverridden
,
CreatedAt
:
createdAt
,
}
...
...
backend/internal/server/api_contract_test.go
View file @
987589ea
...
...
@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) {
"image_count": 0,
"image_size": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
...
...
@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
stubAccountRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/cors.go
View file @
987589ea
...
...
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
allowedSet
[
origin
]
=
struct
{}{}
}
allowHeaders
:=
[]
string
{
"Content-Type"
,
"Content-Length"
,
"Accept-Encoding"
,
"X-CSRF-Token"
,
"Authorization"
,
"accept"
,
"origin"
,
"Cache-Control"
,
"X-Requested-With"
,
"X-API-Key"
,
}
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
openAIProperties
:=
[]
string
{
"lang"
,
"package-version"
,
"os"
,
"arch"
,
"retry-count"
,
"runtime"
,
"runtime-version"
,
"async"
,
"helper-method"
,
"poll-helper"
,
"custom-poll-interval"
,
"timeout"
,
}
for
_
,
prop
:=
range
openAIProperties
{
allowHeaders
=
append
(
allowHeaders
,
"x-stainless-"
+
prop
)
}
allowHeadersValue
:=
strings
.
Join
(
allowHeaders
,
", "
)
return
func
(
c
*
gin
.
Context
)
{
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
...
...
@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if
allowCredentials
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
"Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
allowHeadersValue
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Expose-Headers"
,
"ETag"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Max-Age"
,
"86400"
)
}
// 处理预检请求
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
if
originAllowed
{
...
...
backend/internal/server/routes/admin.go
View file @
987589ea
...
...
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes
(
admin
,
h
)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes
(
admin
,
h
)
// Gemini OAuth
registerGeminiOAuthRoutes
(
admin
,
h
)
...
...
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerSoraOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
sora
:=
admin
.
Group
(
"/sora"
)
{
sora
.
POST
(
"/generate-auth-url"
,
h
.
Admin
.
OpenAIOAuth
.
GenerateAuthURL
)
sora
.
POST
(
"/exchange-code"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeCode
)
sora
.
POST
(
"/refresh-token"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/st2at"
,
h
.
Admin
.
OpenAIOAuth
.
ExchangeSoraSessionToken
)
sora
.
POST
(
"/rt2at"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshToken
)
sora
.
POST
(
"/accounts/:id/refresh"
,
h
.
Admin
.
OpenAIOAuth
.
RefreshAccountToken
)
sora
.
POST
(
"/create-from-oauth"
,
h
.
Admin
.
OpenAIOAuth
.
CreateAccountFromOAuth
)
}
}
func
registerGeminiOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
gemini
:=
admin
.
Group
(
"/gemini"
)
{
...
...
@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies
.
PUT
(
"/:id"
,
h
.
Admin
.
Proxy
.
Update
)
proxies
.
DELETE
(
"/:id"
,
h
.
Admin
.
Proxy
.
Delete
)
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
POST
(
"/:id/quality-check"
,
h
.
Admin
.
Proxy
.
CheckQuality
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
...
...
backend/internal/server/routes/gateway.go
View file @
987589ea
package
routes
import
(
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
...
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
// OpenAI Responses API
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
}
// Sora Chat Completions
soraGateway
:=
r
.
Group
(
"/v1"
)
soraGateway
.
Use
(
soraBodyLimit
)
soraGateway
.
Use
(
clientRequestID
)
soraGateway
.
Use
(
opsErrorLogger
)
soraGateway
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
soraGateway
.
POST
(
"/chat/completions"
,
h
.
SoraGateway
.
ChatCompletions
)
// 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
gateway
.
POST
(
"/chat/completions"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"invalid_request_error"
,
"message"
:
"For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses."
,
},
})
})
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...
...
backend/internal/service/account.go
View file @
987589ea
...
...
@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
return
false
}
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
func
(
a
*
Account
)
IsCacheTTLOverrideEnabled
()
bool
{
if
!
a
.
IsAnthropicOAuthOrSetupToken
()
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"cache_ttl_override_enabled"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
// 返回 "5m" 或 "1h",默认 "5m"
func
(
a
*
Account
)
GetCacheTTLOverrideTarget
()
string
{
if
a
.
Extra
==
nil
{
return
"5m"
}
if
v
,
ok
:=
a
.
Extra
[
"cache_ttl_override_target"
];
ok
{
if
target
,
ok
:=
v
.
(
string
);
ok
&&
(
target
==
"5m"
||
target
==
"1h"
)
{
return
target
}
}
return
"5m"
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
...
...
backend/internal/service/account_service.go
View file @
987589ea
...
...
@@ -35,7 +35,7 @@ type AccountRepository interface {
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
...
...
backend/internal/service/account_service_delete_test.go
View file @
987589ea
...
...
@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic
(
"unexpected List call"
)
}
func
(
s
*
accountRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
accountRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
...
...
backend/internal/service/account_test_service.go
View file @
987589ea
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment