Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
eb2dce92
Commit
eb2dce92
authored
Apr 06, 2026
by
陈曦
Browse files
升级v1.0.8 解决冲突
parents
7b83d6e7
339d906e
Changes
178
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/admin_service_proxy_quality_test.go
View file @
eb2dce92
...
...
@@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
require
.
Contains
(
t
,
result
.
Summary
,
"挑战 1 项"
)
}
func
TestRunProxyQualityTarget_
Sora
Challenge
(
t
*
testing
.
T
)
{
func
TestRunProxyQualityTarget_
Cloudflare
Challenge
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"text/html"
)
w
.
Header
()
.
Set
(
"cf-ray"
,
"test-ray-123"
)
...
...
@@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
defer
server
.
Close
()
target
:=
proxyQualityTarget
{
Target
:
"
sora
"
,
Target
:
"
openai
"
,
URL
:
server
.
URL
,
Method
:
http
.
MethodGet
,
AllowedStatuses
:
map
[
int
]
struct
{}{
...
...
backend/internal/service/antigravity_smart_retry_test.go
View file @
eb2dce92
...
...
@@ -5,13 +5,12 @@ package service
import
(
"bytes"
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
"io"
"net/http"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
...
...
@@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
m
.
responseBodies
[
respIdx
]
=
bodyBytes
}
// 用缓存的 body
字节重建新的
reader
var
body
io
.
ReadCloser
// 用缓存的 body
重建
reader
(支持重试场景多次读取)
cloned
:=
*
resp
if
m
.
responseBodies
[
respIdx
]
!=
nil
{
b
ody
=
io
.
NopCloser
(
bytes
.
NewReader
(
m
.
responseBodies
[
respIdx
]))
cloned
.
B
ody
=
io
.
NopCloser
(
bytes
.
NewReader
(
m
.
responseBodies
[
respIdx
]))
}
return
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
body
,
},
respErr
return
&
cloned
,
respErr
}
func
(
m
*
mockSmartRetryUpstream
)
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
profile
*
tlsfingerprint
.
Profile
)
(
*
http
.
Response
,
error
)
{
...
...
backend/internal/service/api_key_auth_cache.go
View file @
eb2dce92
...
...
@@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360,omitempty"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request,omitempty"`
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
eb2dce92
...
...
@@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
,
...
...
@@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
snapshot
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
snapshot
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
snapshot
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
snapshot
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
snapshot
.
Group
.
FallbackGroupIDOnInvalidRequest
,
...
...
backend/internal/service/billing_service.go
View file @
eb2dce92
...
...
@@ -808,14 +808,6 @@ type ImagePriceConfig struct {
Price4K
*
float64
// 4K 尺寸价格(nil 表示使用默认值)
}
// SoraPriceConfig Sora 按次计费配置
type
SoraPriceConfig
struct
{
ImagePrice360
*
float64
ImagePrice540
*
float64
VideoPricePerRequest
*
float64
VideoPricePerRequestHD
*
float64
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
...
...
@@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
// CalculateSoraImageCost 计算 Sora 图片按次费用
func
(
s
*
BillingService
)
CalculateSoraImageCost
(
imageSize
string
,
imageCount
int
,
groupConfig
*
SoraPriceConfig
,
rateMultiplier
float64
)
*
CostBreakdown
{
if
imageCount
<=
0
{
return
&
CostBreakdown
{}
}
unitPrice
:=
0.0
if
groupConfig
!=
nil
{
switch
imageSize
{
case
"540"
:
if
groupConfig
.
ImagePrice540
!=
nil
{
unitPrice
=
*
groupConfig
.
ImagePrice540
}
default
:
if
groupConfig
.
ImagePrice360
!=
nil
{
unitPrice
=
*
groupConfig
.
ImagePrice360
}
}
}
totalCost
:=
unitPrice
*
float64
(
imageCount
)
if
rateMultiplier
<=
0
{
rateMultiplier
=
1.0
}
actualCost
:=
totalCost
*
rateMultiplier
return
&
CostBreakdown
{
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func
(
s
*
BillingService
)
CalculateSoraVideoCost
(
model
string
,
groupConfig
*
SoraPriceConfig
,
rateMultiplier
float64
)
*
CostBreakdown
{
unitPrice
:=
0.0
if
groupConfig
!=
nil
{
modelLower
:=
strings
.
ToLower
(
model
)
if
strings
.
Contains
(
modelLower
,
"sora2pro-hd"
)
{
if
groupConfig
.
VideoPricePerRequestHD
!=
nil
{
unitPrice
=
*
groupConfig
.
VideoPricePerRequestHD
}
}
if
unitPrice
<=
0
&&
groupConfig
.
VideoPricePerRequest
!=
nil
{
unitPrice
=
*
groupConfig
.
VideoPricePerRequest
}
}
totalCost
:=
unitPrice
if
rateMultiplier
<=
0
{
rateMultiplier
=
1.0
}
actualCost
:=
totalCost
*
rateMultiplier
return
&
CostBreakdown
{
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
}
}
// getImageUnitPrice 获取图片单价
func
(
s
*
BillingService
)
getImageUnitPrice
(
model
string
,
imageSize
string
,
groupConfig
*
ImagePriceConfig
)
float64
{
// 优先使用分组配置的价格
...
...
backend/internal/service/billing_service_test.go
View file @
eb2dce92
...
...
@@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateSoraVideoCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
price
:=
0.5
cfg
:=
&
SoraPriceConfig
{
VideoPricePerRequest
:
&
price
}
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora-video"
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
0.5
,
cost
.
TotalCost
,
1e-10
)
}
func
TestCalculateSoraVideoCost_HDModel
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
hdPrice
:=
1.0
normalPrice
:=
0.5
cfg
:=
&
SoraPriceConfig
{
VideoPricePerRequest
:
&
normalPrice
,
VideoPricePerRequestHD
:
&
hdPrice
,
}
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora2pro-hd"
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
1.0
,
cost
.
TotalCost
,
1e-10
)
}
func
TestIsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
...
...
@@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"not initialized"
)
}
func
TestCalculateSoraImageCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
price360
:=
0.05
price540
:=
0.08
cfg
:=
&
SoraPriceConfig
{
ImagePrice360
:
&
price360
,
ImagePrice540
:
&
price540
}
cost
:=
svc
.
CalculateSoraImageCost
(
"360"
,
2
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
0.10
,
cost
.
TotalCost
,
1e-10
)
cost540
:=
svc
.
CalculateSoraImageCost
(
"540"
,
1
,
cfg
,
2.0
)
require
.
InDelta
(
t
,
0.08
,
cost540
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
0.16
,
cost540
.
ActualCost
,
1e-10
)
}
func
TestCalculateSoraImageCost_ZeroCount
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
cost
:=
svc
.
CalculateSoraImageCost
(
"360"
,
0
,
nil
,
1.0
)
require
.
Equal
(
t
,
0.0
,
cost
.
TotalCost
)
}
func
TestCalculateSoraVideoCost_NilConfig
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora-video"
,
nil
,
1.0
)
require
.
Equal
(
t
,
0.0
,
cost
.
TotalCost
)
}
func
TestCalculateCostWithLongContext_PropagatesError
(
t
*
testing
.
T
)
{
// 使用空的 fallback prices 让 GetModelPricing 失败
svc
:=
&
BillingService
{
...
...
backend/internal/service/channel_service.go
View file @
eb2dce92
...
...
@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
func
expandPricingToCache
(
cache
*
channelCache
,
ch
*
Channel
,
gid
int64
,
platform
string
)
{
for
j
:=
range
ch
.
ModelPricing
{
pricing
:=
&
ch
.
ModelPricing
[
j
]
...
...
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
func
expandMappingToCache
(
cache
*
channelCache
,
ch
*
Channel
,
gid
int64
,
platform
string
)
{
for
_
,
mappingPlatform
:=
range
matchingPlatforms
(
platform
)
{
platformMapping
,
ok
:=
ch
.
ModelMapping
[
mappingPlatform
]
...
...
@@ -311,23 +308,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
func
isPlatformPricingMatch
(
groupPlatform
,
pricingPlatform
string
)
bool
{
if
groupPlatform
==
pricingPlatform
{
return
true
}
if
groupPlatform
==
PlatformAntigravity
{
return
pricingPlatform
==
PlatformAnthropic
||
pricingPlatform
==
PlatformGemini
}
return
false
return
groupPlatform
==
pricingPlatform
}
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
// 各平台严格独立,只返回自身。
func
matchingPlatforms
(
groupPlatform
string
)
[]
string
{
if
groupPlatform
==
PlatformAntigravity
{
return
[]
string
{
PlatformAntigravity
,
PlatformAnthropic
,
PlatformGemini
}
}
return
[]
string
{
groupPlatform
}
}
func
(
s
*
ChannelService
)
invalidateCache
()
{
...
...
@@ -364,10 +352,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return
""
}
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
func
lookupPricingAcrossPlatforms
(
cache
*
channelCache
,
groupID
int64
,
groupPlatform
,
modelLower
string
)
*
ChannelModelPricing
{
for
_
,
p
:=
range
matchingPlatforms
(
groupPlatform
)
{
key
:=
channelModelKey
{
groupID
:
groupID
,
platform
:
p
,
model
:
modelLower
}
...
...
@@ -384,7 +370,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return
nil
}
// lookupMappingAcrossPlatforms 在
所有匹配
平台
中
查找模型映射。
// lookupMappingAcrossPlatforms 在
分组
平台
内
查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func
lookupMappingAcrossPlatforms
(
cache
*
channelCache
,
groupID
int64
,
groupPlatform
,
modelLower
string
)
string
{
for
_
,
p
:=
range
matchingPlatforms
(
groupPlatform
)
{
...
...
@@ -442,8 +428,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
// 确保跨平台同名模型各自独立匹配。
// 各平台严格独立,只在本平台内查找定价。
func
(
s
*
ChannelService
)
GetChannelModelPricing
(
ctx
context
.
Context
,
groupID
int64
,
model
string
)
*
ChannelModelPricing
{
lk
,
err
:=
s
.
lookupGroupChannel
(
ctx
,
groupID
)
if
err
!=
nil
{
...
...
@@ -524,7 +509,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
//
antigravity 分组依次尝试所有匹配
平台的定价列表。
//
只在本
平台的定价列表
中查找
。
func
checkRestricted
(
lk
*
channelLookup
,
groupID
int64
,
model
string
)
bool
{
if
!
lk
.
channel
.
RestrictModels
{
return
false
...
...
backend/internal/service/channel_service_test.go
View file @
eb2dce92
...
...
@@ -1932,8 +1932,8 @@ func TestIsPlatformPricingMatch(t *testing.T) {
pricingPlatform
string
want
bool
}{
{
"antigravity match
es
anthropic"
,
PlatformAntigravity
,
PlatformAnthropic
,
tru
e
},
{
"antigravity match
es
gemini"
,
PlatformAntigravity
,
PlatformGemini
,
tru
e
},
{
"antigravity
does NOT
match anthropic"
,
PlatformAntigravity
,
PlatformAnthropic
,
fals
e
},
{
"antigravity
does NOT
match gemini"
,
PlatformAntigravity
,
PlatformGemini
,
fals
e
},
{
"antigravity matches antigravity"
,
PlatformAntigravity
,
PlatformAntigravity
,
true
},
{
"antigravity does NOT match openai"
,
PlatformAntigravity
,
PlatformOpenAI
,
false
},
{
"anthropic matches anthropic"
,
PlatformAnthropic
,
PlatformAnthropic
,
true
},
...
...
@@ -1963,7 +1963,7 @@ func TestMatchingPlatforms(t *testing.T) {
groupPlatform
string
want
[]
string
}{
{
"antigravity returns
all three
"
,
PlatformAntigravity
,
[]
string
{
PlatformAntigravity
,
PlatformAnthropic
,
PlatformGemini
}},
{
"antigravity returns
itself only
"
,
PlatformAntigravity
,
[]
string
{
PlatformAntigravity
}},
{
"anthropic returns itself"
,
PlatformAnthropic
,
[]
string
{
PlatformAnthropic
}},
{
"gemini returns itself"
,
PlatformGemini
,
[]
string
{
PlatformGemini
}},
{
"openai returns itself"
,
PlatformOpenAI
,
[]
string
{
PlatformOpenAI
}},
...
...
@@ -1978,12 +1978,12 @@ func TestMatchingPlatforms(t *testing.T) {
}
// ===========================================================================
// 9. Antigravity
cross-platform channel pricing
// 9. Antigravity
platform isolation — no cross-platform pricing leakage
// ===========================================================================
func
TestGetChannelModelPricing_AntigravityCrossPlatform
(
t
*
testing
.
T
)
{
func
TestGetChannelModelPricing_Antigravity
DoesNotSee
CrossPlatform
Pricing
(
t
*
testing
.
T
)
{
// Channel has anthropic pricing for claude-opus-4-6.
// Group 10 is antigravity — should see the anthropic pricing.
// Group 10 is antigravity — should
NOT
see the anthropic pricing.
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -1996,9 +1996,7 @@ func TestGetChannelModelPricing_AntigravityCrossPlatform(t *testing.T) {
svc
:=
newTestChannelService
(
repo
)
result
:=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"claude-opus-4-6"
)
require
.
NotNil
(
t
,
result
,
"antigravity group should see anthropic pricing"
)
require
.
Equal
(
t
,
int64
(
100
),
result
.
ID
)
require
.
InDelta
(
t
,
15e-6
,
*
result
.
InputPrice
,
1e-12
)
require
.
Nil
(
t
,
result
,
"antigravity group should NOT see anthropic-platform pricing"
)
}
func
TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing
(
t
*
testing
.
T
)
{
...
...
@@ -2020,12 +2018,12 @@ func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.
}
// ===========================================================================
// 10. Antigravity cross-platform model mapping
// 10. Antigravity
platform isolation — no
cross-platform model mapping
// ===========================================================================
func
TestResolveChannelMapping_AntigravityCrossPlatform
(
t
*
testing
.
T
)
{
func
TestResolveChannelMapping_Antigravity
DoesNotSee
CrossPlatform
Mapping
(
t
*
testing
.
T
)
{
// Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6.
// Group 10 is antigravity — should apply the anthropic mapping.
// Group 10 is antigravity — should
NOT
apply the anthropic mapping.
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2040,18 +2038,17 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
svc
:=
newTestChannelService
(
repo
)
result
:=
svc
.
ResolveChannelMapping
(
context
.
Background
(),
10
,
"claude-opus-4-5"
)
require
.
True
(
t
,
result
.
Mapped
,
"antigravity group should apply anthropic mapping"
)
require
.
Equal
(
t
,
"claude-opus-4-6"
,
result
.
MappedModel
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
ChannelID
)
require
.
False
(
t
,
result
.
Mapped
,
"antigravity group should NOT apply anthropic mapping"
)
require
.
Equal
(
t
,
"claude-opus-4-5"
,
result
.
MappedModel
)
}
// ===========================================================================
// 11. Antigravity
cross-
platform same-name model
— no overwrite
// 11. Antigravity platform
isolation —
same-name model
across platforms
// ===========================================================================
func
TestGetChannelModelPricing_Antigravity
SameModelDifferent
Platforms
(
t
*
testing
.
T
)
{
func
TestGetChannelModelPricing_Antigravity
DoesNotSeeSameModelFromOther
Platforms
(
t
*
testing
.
T
)
{
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
// antigravity 分组
应能分别查到各自的定价,而不是后者覆盖前者
。
// antigravity 分组
不应看到任何一个(各平台严格独立)
。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2064,17 +2061,13 @@ func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testin
repo
:=
makeStandardRepo
(
ch
,
map
[
int64
]
string
{
10
:
PlatformAntigravity
})
svc
:=
newTestChannelService
(
repo
)
// antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini)
result
:=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"shared-model"
)
require
.
NotNil
(
t
,
result
,
"antigravity group should find pricing for shared-model"
)
// 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini])
require
.
Equal
(
t
,
int64
(
200
),
result
.
ID
)
require
.
InDelta
(
t
,
10e-6
,
*
result
.
InputPrice
,
1e-12
)
require
.
Nil
(
t
,
result
,
"antigravity group should NOT see anthropic/gemini-platform pricing"
)
}
func
TestGetChannelModelPricing_Antigravity
Only
GeminiPricing
(
t
*
testing
.
T
)
{
func
TestGetChannelModelPricing_Antigravity
DoesNotSee
Gemini
Only
Pricing
(
t
*
testing
.
T
)
{
// 只有 gemini 平台定义了模型 "gemini-model"。
// antigravity 分组
应能查
到 gemini 的定价。
// antigravity 分组
不应看
到 gemini 的定价。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2087,14 +2080,12 @@ func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) {
svc
:=
newTestChannelService
(
repo
)
result
:=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"gemini-model"
)
require
.
NotNil
(
t
,
result
,
"antigravity group should find gemini pricing"
)
require
.
Equal
(
t
,
int64
(
300
),
result
.
ID
)
require
.
InDelta
(
t
,
2e-6
,
*
result
.
InputPrice
,
1e-12
)
require
.
Nil
(
t
,
result
,
"antigravity group should NOT see gemini-platform pricing"
)
}
func
TestGetChannelModelPricing_AntigravityWildcard
C
ro
ss
Platform
NoOverwrite
(
t
*
testing
.
T
)
{
// anthropic 和 gemini 都有 "shared-*" 通配符定价
,价格不同
。
// antigravity 分组
查找 "shared-model" 应命中第一个匹配而非被覆盖
。
func
TestGetChannelModelPricing_Antigravity
DoesNotSee
Wildcard
F
ro
mOther
Platform
s
(
t
*
testing
.
T
)
{
// anthropic 和 gemini 都有 "shared-*" 通配符定价。
// antigravity 分组
不应命中任何一个
。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2108,15 +2099,12 @@ func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *t
svc
:=
newTestChannelService
(
repo
)
result
:=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"shared-model"
)
require
.
NotNil
(
t
,
result
,
"antigravity group should find wildcard pricing for shared-model"
)
// 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序)
require
.
Equal
(
t
,
int64
(
400
),
result
.
ID
)
require
.
InDelta
(
t
,
10e-6
,
*
result
.
InputPrice
,
1e-12
)
require
.
Nil
(
t
,
result
,
"antigravity group should NOT see wildcard pricing from other platforms"
)
}
func
TestResolveChannelMapping_Antigravity
SameModelDifferent
Platforms
(
t
*
testing
.
T
)
{
func
TestResolveChannelMapping_Antigravity
DoesNotSeeMappingFromOther
Platforms
(
t
*
testing
.
T
)
{
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
// antigravity 分组应命中
anthropic 的映射(按 matchingPlatforms 顺序)
。
// antigravity 分组
不
应命中
任何一个
。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2130,13 +2118,13 @@ func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing
svc
:=
newTestChannelService
(
repo
)
result
:=
svc
.
ResolveChannelMapping
(
context
.
Background
(),
10
,
"alias"
)
require
.
Tru
e
(
t
,
result
.
Mapped
)
require
.
Equal
(
t
,
"a
nthropic-target
"
,
result
.
MappedModel
)
require
.
Fals
e
(
t
,
result
.
Mapped
,
"antigravity group should NOT see mapping from other platforms"
)
require
.
Equal
(
t
,
"a
lias
"
,
result
.
MappedModel
)
}
func
TestCheckRestricted_Antigravity
SameModelDifferent
Platforms
(
t
*
testing
.
T
)
{
func
TestCheckRestricted_Antigravity
DoesNotSeeModelsFromOther
Platforms
(
t
*
testing
.
T
)
{
// anthropic 和 gemini 都定义了同名模型 "shared-model"。
// antigravity 分组启用了 RestrictModels,"shared-model" 应
不
被限制。
// antigravity 分组启用了 RestrictModels,"shared-model" 应被限制
(各平台独立)
。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
...
...
@@ -2151,13 +2139,39 @@ func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) {
svc
:=
newTestChannelService
(
repo
)
restricted
:=
svc
.
IsModelRestricted
(
context
.
Background
(),
10
,
"shared-model"
)
require
.
Fals
e
(
t
,
restricted
,
"shared-model should
not
be restricted for antigravity"
)
require
.
Tru
e
(
t
,
restricted
,
"shared-model
from other platforms
should be restricted for antigravity"
)
// 未定义的模型应被限制
restricted
=
svc
.
IsModelRestricted
(
context
.
Background
(),
10
,
"unknown-model"
)
require
.
True
(
t
,
restricted
,
"unknown-model should be restricted for antigravity"
)
}
func
TestGetChannelModelPricing_AntigravityOwnPricingWorks
(
t
*
testing
.
T
)
{
// antigravity 平台自己配置的定价应正常生效(覆盖 Claude 和 Gemini 模型)。
ch
:=
Channel
{
ID
:
1
,
Status
:
StatusActive
,
GroupIDs
:
[]
int64
{
10
},
ModelPricing
:
[]
ChannelModelPricing
{
{
ID
:
600
,
Platform
:
PlatformAntigravity
,
Models
:
[]
string
{
"claude-*"
},
InputPrice
:
testPtrFloat64
(
15e-6
)},
{
ID
:
601
,
Platform
:
PlatformAntigravity
,
Models
:
[]
string
{
"gemini-*"
},
InputPrice
:
testPtrFloat64
(
2e-6
)},
},
}
repo
:=
makeStandardRepo
(
ch
,
map
[
int64
]
string
{
10
:
PlatformAntigravity
})
svc
:=
newTestChannelService
(
repo
)
// Claude 模型匹配 antigravity 定价
result
:=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"claude-sonnet-4"
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
int64
(
600
),
result
.
ID
)
require
.
InDelta
(
t
,
15e-6
,
*
result
.
InputPrice
,
1e-12
)
// Gemini 模型匹配 antigravity 定价
result
=
svc
.
GetChannelModelPricing
(
context
.
Background
(),
10
,
"gemini-2.5-flash"
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
int64
(
601
),
result
.
ID
)
require
.
InDelta
(
t
,
2e-6
,
*
result
.
InputPrice
,
1e-12
)
}
func
TestGetChannelModelPricing_NonAntigravityUnaffected
(
t
*
testing
.
T
)
{
// 确保非 antigravity 平台的行为不受影响。
// anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。
...
...
backend/internal/service/domain_constants.go
View file @
eb2dce92
...
...
@@ -24,7 +24,6 @@ const (
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformGemini
=
domain
.
PlatformGemini
PlatformAntigravity
=
domain
.
PlatformAntigravity
PlatformSora
=
domain
.
PlatformSora
)
// Account type constants
...
...
@@ -107,7 +106,6 @@ const (
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled
=
"sora_client_enabled"
// 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
...
...
@@ -199,27 +197,6 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings
=
"beta_policy_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled
=
"sora_s3_enabled"
// 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint
=
"sora_s3_endpoint"
// S3 端点地址
SettingKeySoraS3Region
=
"sora_s3_region"
// S3 区域
SettingKeySoraS3Bucket
=
"sora_s3_bucket"
// S3 存储桶名称
SettingKeySoraS3AccessKeyID
=
"sora_s3_access_key_id"
// S3 Access Key ID
SettingKeySoraS3SecretAccessKey
=
"sora_s3_secret_access_key"
// S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix
=
"sora_s3_prefix"
// S3 对象键前缀
SettingKeySoraS3ForcePathStyle
=
"sora_s3_force_path_style"
// 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL
=
"sora_s3_cdn_url"
// CDN 加速 URL(可选)
SettingKeySoraS3Profiles
=
"sora_s3_profiles"
// Sora S3 多配置(JSON)
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes
=
"sora_default_storage_quota_bytes"
// 新用户默认 Sora 存储配额(字节)
// =========================
// Claude Code Version Check
// =========================
...
...
backend/internal/service/gateway_service.go
View file @
eb2dce92
...
...
@@ -60,13 +60,6 @@ const (
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const
(
MediaTypeImage
=
"image"
MediaTypeVideo
=
"video"
MediaTypePrompt
=
"prompt"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type
forceCacheBillingKeyType
struct
{}
...
...
@@ -511,9 +504,6 @@ type ForwardResult struct {
ImageCount
int
// 生成的图片数量
ImageSize
string
// 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType
string
// image / video / prompt
MediaURL
string
// 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
...
...
@@ -1971,9 +1961,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func
(
s
*
GatewayService
)
listSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
bool
,
error
)
{
if
platform
==
PlatformSora
{
return
s
.
listSoraSchedulableAccounts
(
ctx
,
groupID
)
}
if
s
.
schedulerSnapshot
!=
nil
{
accounts
,
useMixed
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
==
nil
{
...
...
@@ -2070,53 +2057,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return
accounts
,
useMixed
,
nil
}
func
(
s
*
GatewayService
)
listSoraSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
([]
Account
,
bool
,
error
)
{
const
useMixed
=
false
var
accounts
[]
Account
var
err
error
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
accounts
,
err
=
s
.
accountRepo
.
ListByPlatform
(
ctx
,
PlatformSora
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListByGroup
(
ctx
,
*
groupID
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListByPlatform
(
ctx
,
PlatformSora
)
}
if
err
!=
nil
{
slog
.
Debug
(
"account_scheduling_list_failed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
PlatformSora
,
"error"
,
err
)
return
nil
,
useMixed
,
err
}
filtered
:=
make
([]
Account
,
0
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
if
acc
.
Platform
!=
PlatformSora
{
continue
}
if
!
s
.
isSoraAccountSchedulable
(
&
acc
)
{
continue
}
filtered
=
append
(
filtered
,
acc
)
}
slog
.
Debug
(
"account_scheduling_list_sora"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
PlatformSora
,
"raw_count"
,
len
(
accounts
),
"filtered_count"
,
len
(
filtered
))
for
_
,
acc
:=
range
filtered
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
return
filtered
,
useMixed
,
nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
...
...
@@ -2141,33 +2081,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return
account
.
Platform
==
platform
}
func
(
s
*
GatewayService
)
isSoraAccountSchedulable
(
account
*
Account
)
bool
{
return
s
.
soraUnschedulableReason
(
account
)
==
""
}
func
(
s
*
GatewayService
)
soraUnschedulableReason
(
account
*
Account
)
string
{
if
account
==
nil
{
return
"account_nil"
}
if
account
.
Status
!=
StatusActive
{
return
fmt
.
Sprintf
(
"status=%s"
,
account
.
Status
)
}
if
!
account
.
Schedulable
{
return
"schedulable=false"
}
if
account
.
TempUnschedulableUntil
!=
nil
&&
time
.
Now
()
.
Before
(
*
account
.
TempUnschedulableUntil
)
{
return
fmt
.
Sprintf
(
"temp_unschedulable_until=%s"
,
account
.
TempUnschedulableUntil
.
UTC
()
.
Format
(
time
.
RFC3339
))
}
return
""
}
func
(
s
*
GatewayService
)
isAccountSchedulableForSelection
(
account
*
Account
)
bool
{
if
account
==
nil
{
return
false
}
if
account
.
Platform
==
PlatformSora
{
return
s
.
isSoraAccountSchedulable
(
account
)
}
return
account
.
IsSchedulable
()
}
...
...
@@ -2175,12 +2092,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if
account
==
nil
{
return
false
}
if
account
.
Platform
==
PlatformSora
{
if
!
s
.
isSoraAccountSchedulable
(
account
)
{
return
false
}
return
account
.
GetRateLimitRemainingTimeWithContext
(
ctx
,
requestedModel
)
<=
0
}
return
account
.
IsSchedulableForModelWithContext
(
ctx
,
requestedModel
)
}
...
...
@@ -3357,9 +3268,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats
.
SampleMappingIDs
,
stats
.
SampleRateLimitIDs
,
)
if
platform
==
PlatformSora
{
s
.
logSoraSelectionFailureDetails
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
accounts
,
excludedIDs
,
allowMixedScheduling
)
}
return
stats
}
...
...
@@ -3416,11 +3324,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return
selectionFailureDiagnosis
{
Category
:
"excluded"
}
}
if
!
s
.
isAccountSchedulableForSelection
(
acc
)
{
detail
:=
"generic_unschedulable"
if
acc
.
Platform
==
PlatformSora
{
detail
=
s
.
soraUnschedulableReason
(
acc
)
}
return
selectionFailureDiagnosis
{
Category
:
"unschedulable"
,
Detail
:
detail
}
return
selectionFailureDiagnosis
{
Category
:
"unschedulable"
,
Detail
:
"generic_unschedulable"
}
}
if
isPlatformFilteredForSelection
(
acc
,
platform
,
allowMixedScheduling
)
{
return
selectionFailureDiagnosis
{
...
...
@@ -3444,57 +3348,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return
selectionFailureDiagnosis
{
Category
:
"eligible"
}
}
func
(
s
*
GatewayService
)
logSoraSelectionFailureDetails
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
accounts
[]
Account
,
excludedIDs
map
[
int64
]
struct
{},
allowMixedScheduling
bool
,
)
{
const
maxLines
=
30
logged
:=
0
for
i
:=
range
accounts
{
if
logged
>=
maxLines
{
break
}
acc
:=
&
accounts
[
i
]
diagnosis
:=
s
.
diagnoseSelectionFailure
(
ctx
,
acc
,
requestedModel
,
PlatformSora
,
excludedIDs
,
allowMixedScheduling
)
if
diagnosis
.
Category
==
"eligible"
{
continue
}
detail
:=
diagnosis
.
Detail
if
detail
==
""
{
detail
=
"-"
}
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
acc
.
ID
,
acc
.
Platform
,
diagnosis
.
Category
,
detail
,
)
logged
++
}
if
len
(
accounts
)
>
maxLines
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
len
(
accounts
),
logged
,
)
}
}
func
isPlatformFilteredForSelection
(
acc
*
Account
,
platform
string
,
allowMixedScheduling
bool
)
bool
{
if
acc
==
nil
{
return
true
...
...
@@ -3573,9 +3426,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return
mapAntigravityModel
(
account
,
requestedModel
)
!=
""
}
if
account
.
Platform
==
PlatformSora
{
return
s
.
isSoraModelSupportedByAccount
(
account
,
requestedModel
)
}
if
account
.
IsBedrock
()
{
_
,
ok
:=
ResolveBedrockModelID
(
account
,
requestedModel
)
return
ok
...
...
@@ -3588,143 +3438,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return
account
.
IsModelSupported
(
requestedModel
)
}
func
(
s
*
GatewayService
)
isSoraModelSupportedByAccount
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
==
nil
{
return
false
}
if
strings
.
TrimSpace
(
requestedModel
)
==
""
{
return
true
}
// 先走原始精确/通配符匹配。
mapping
:=
account
.
GetModelMapping
()
if
len
(
mapping
)
==
0
||
account
.
IsModelSupported
(
requestedModel
)
{
return
true
}
aliases
:=
buildSoraModelAliases
(
requestedModel
)
if
len
(
aliases
)
==
0
{
return
false
}
hasSoraSelector
:=
false
for
pattern
:=
range
mapping
{
if
!
isSoraModelSelector
(
pattern
)
{
continue
}
hasSoraSelector
=
true
if
matchPatternAnyAlias
(
pattern
,
aliases
)
{
return
true
}
}
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
// 此时不应误拦截 Sora 模型请求。
if
!
hasSoraSelector
{
return
true
}
return
false
}
func
matchPatternAnyAlias
(
pattern
string
,
aliases
[]
string
)
bool
{
normalizedPattern
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
pattern
))
if
normalizedPattern
==
""
{
return
false
}
for
_
,
alias
:=
range
aliases
{
if
matchWildcard
(
normalizedPattern
,
alias
)
{
return
true
}
}
return
false
}
func
isSoraModelSelector
(
pattern
string
)
bool
{
p
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
pattern
))
if
p
==
""
{
return
false
}
switch
{
case
strings
.
HasPrefix
(
p
,
"sora"
),
strings
.
HasPrefix
(
p
,
"gpt-image"
),
strings
.
HasPrefix
(
p
,
"prompt-enhance"
),
strings
.
HasPrefix
(
p
,
"sy_"
)
:
return
true
}
return
p
==
"video"
||
p
==
"image"
}
func
buildSoraModelAliases
(
requestedModel
string
)
[]
string
{
modelID
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
requestedModel
))
if
modelID
==
""
{
return
nil
}
aliases
:=
make
([]
string
,
0
,
8
)
addAlias
:=
func
(
value
string
)
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
value
))
if
v
==
""
{
return
}
for
_
,
existing
:=
range
aliases
{
if
existing
==
v
{
return
}
}
aliases
=
append
(
aliases
,
v
)
}
addAlias
(
modelID
)
cfg
,
ok
:=
GetSoraModelConfig
(
modelID
)
if
ok
{
addAlias
(
cfg
.
Model
)
switch
cfg
.
Type
{
case
"video"
:
addAlias
(
"video"
)
addAlias
(
"sora"
)
addAlias
(
soraVideoFamilyAlias
(
modelID
))
case
"image"
:
addAlias
(
"image"
)
addAlias
(
"gpt-image"
)
case
"prompt_enhance"
:
addAlias
(
"prompt-enhance"
)
}
return
aliases
}
switch
{
case
strings
.
HasPrefix
(
modelID
,
"sora"
)
:
addAlias
(
"video"
)
addAlias
(
"sora"
)
addAlias
(
soraVideoFamilyAlias
(
modelID
))
case
strings
.
HasPrefix
(
modelID
,
"gpt-image"
)
:
addAlias
(
"image"
)
addAlias
(
"gpt-image"
)
case
strings
.
HasPrefix
(
modelID
,
"prompt-enhance"
)
:
addAlias
(
"prompt-enhance"
)
default
:
return
nil
}
return
aliases
}
func
soraVideoFamilyAlias
(
modelID
string
)
string
{
switch
{
case
strings
.
HasPrefix
(
modelID
,
"sora2pro-hd"
)
:
return
"sora2pro-hd"
case
strings
.
HasPrefix
(
modelID
,
"sora2pro"
)
:
return
"sora2pro"
case
strings
.
HasPrefix
(
modelID
,
"sora2"
)
:
return
"sora2"
default
:
return
""
}
}
// GetAccessToken 获取账号凭证
func
(
s
*
GatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
...
...
@@ -7592,9 +7305,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd
.
CacheCreationTokens
=
usageLog
.
CacheCreationTokens
cmd
.
CacheReadTokens
=
usageLog
.
CacheReadTokens
cmd
.
ImageCount
=
usageLog
.
ImageCount
if
usageLog
.
MediaType
!=
nil
{
cmd
.
MediaType
=
*
usageLog
.
MediaType
}
if
usageLog
.
ServiceTier
!=
nil
{
cmd
.
ServiceTier
=
*
usageLog
.
ServiceTier
}
...
...
@@ -7750,8 +7460,6 @@ type recordUsageOpts struct {
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志
EnableClaudePath
bool
// 长上下文计费(仅 Gemini 路径需要)
...
...
@@ -7842,7 +7550,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func
(
s
*
GatewayService
)
recordUsageCore
(
ctx
context
.
Context
,
input
*
recordUsageCoreInput
,
opts
*
recordUsageOpts
)
error
{
result
:=
input
.
Result
...
...
@@ -7944,16 +7651,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier
float64
,
opts
*
recordUsageOpts
,
)
*
CostBreakdown
{
// Sora 媒体类型分支(仅 Claude 路径启用)
if
opts
.
EnableClaudePath
{
if
result
.
MediaType
==
MediaTypeImage
||
result
.
MediaType
==
MediaTypeVideo
{
return
s
.
calculateSoraMediaCost
(
result
,
apiKey
,
billingModel
,
multiplier
)
}
if
result
.
MediaType
==
MediaTypePrompt
{
return
&
CostBreakdown
{}
}
}
// 图片生成计费
if
result
.
ImageCount
>
0
{
return
s
.
calculateImageCost
(
ctx
,
result
,
apiKey
,
billingModel
,
multiplier
)
...
...
@@ -7963,28 +7660,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return
s
.
calculateTokenCost
(
ctx
,
result
,
apiKey
,
billingModel
,
multiplier
,
opts
)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func
(
s
*
GatewayService
)
calculateSoraMediaCost
(
result
*
ForwardResult
,
apiKey
*
APIKey
,
billingModel
string
,
multiplier
float64
,
)
*
CostBreakdown
{
var
soraConfig
*
SoraPriceConfig
if
apiKey
.
Group
!=
nil
{
soraConfig
=
&
SoraPriceConfig
{
ImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
ImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
VideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
VideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
}
}
if
result
.
MediaType
==
MediaTypeImage
{
return
s
.
billingService
.
CalculateSoraImageCost
(
result
.
ImageSize
,
result
.
ImageCount
,
soraConfig
,
multiplier
)
}
return
s
.
billingService
.
CalculateSoraVideoCost
(
billingModel
,
soraConfig
,
multiplier
)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func
(
s
*
GatewayService
)
resolveChannelPricing
(
ctx
context
.
Context
,
billingModel
string
,
apiKey
*
APIKey
)
*
ResolvedPricing
{
...
...
@@ -8133,13 +7808,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
BillingMode
:
resolveBillingMode
(
opts
,
result
,
cost
),
BillingMode
:
resolveBillingMode
(
result
,
cost
),
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
optionalTrimmedStringPtr
(
result
.
ImageSize
),
MediaType
:
resolveMediaType
(
opts
,
result
),
CacheTTLOverridden
:
cacheTTLOverridden
,
ChannelID
:
optionalInt64Ptr
(
input
.
ChannelID
),
ModelMappingChain
:
optionalTrimmedStringPtr
(
input
.
ModelMappingChain
),
...
...
@@ -8163,13 +7837,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func
resolveBillingMode
(
opts
*
recordUsageOpts
,
result
*
ForwardResult
,
cost
*
CostBreakdown
)
*
string
{
isSoraMedia
:=
opts
.
EnableClaudePath
&&
(
result
.
MediaType
==
MediaTypeImage
||
result
.
MediaType
==
MediaTypeVideo
||
result
.
MediaType
==
MediaTypePrompt
)
if
isSoraMedia
{
return
nil
}
func
resolveBillingMode
(
result
*
ForwardResult
,
cost
*
CostBreakdown
)
*
string
{
var
mode
string
switch
{
case
cost
!=
nil
&&
cost
.
BillingMode
!=
""
:
...
...
@@ -8182,13 +7850,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return
&
mode
}
func
resolveMediaType
(
opts
*
recordUsageOpts
,
result
*
ForwardResult
)
*
string
{
if
opts
.
EnableClaudePath
&&
strings
.
TrimSpace
(
result
.
MediaType
)
!=
""
{
return
&
result
.
MediaType
}
return
nil
}
func
optionalSubscriptionID
(
subscription
*
UserSubscription
)
*
int64
{
if
subscription
!=
nil
{
return
&
subscription
.
ID
...
...
backend/internal/service/gateway_service_selection_failure_stats_test.go
View file @
eb2dce92
...
...
@@ -9,35 +9,35 @@ import (
func
TestCollectSelectionFailureStats
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"
sora2-landscape-10s
"
model
:=
"
gpt-5.4
"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
accounts
:=
[]
Account
{
// excluded
{
ID
:
1
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
// unschedulable
{
ID
:
2
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
false
,
},
// platform filtered
{
ID
:
3
,
Platform
:
Platform
OpenAI
,
Platform
:
Platform
Antigravity
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
// model unsupported
{
ID
:
4
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
...
...
@@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// model rate limited
{
ID
:
5
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
...
...
@@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// eligible
{
ID
:
6
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
},
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
model
,
Platform
Sora
,
excluded
,
false
)
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
model
,
Platform
OpenAI
,
excluded
,
false
)
if
stats
.
Total
!=
6
{
t
.
Fatalf
(
"total=%d want=6"
,
stats
.
Total
)
...
...
@@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
}
}
func
TestDiagnoseSelectionFailure_
Sora
UnschedulableDetail
(
t
*
testing
.
T
)
{
func
TestDiagnoseSelectionFailure_UnschedulableDetail
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
acc
:=
&
Account
{
ID
:
7
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
false
,
}
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
"
sora2-landscape-10s
"
,
Platform
Sora
,
map
[
int64
]
struct
{}{},
false
)
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
"
gpt-5.4
"
,
Platform
OpenAI
,
map
[
int64
]
struct
{}{},
false
)
if
diagnosis
.
Category
!=
"unschedulable"
{
t
.
Fatalf
(
"category=%s want=unschedulable"
,
diagnosis
.
Category
)
}
if
diagnosis
.
Detail
!=
"schedulable
=false
"
{
t
.
Fatalf
(
"detail=%s want=schedulable
=false
"
,
diagnosis
.
Detail
)
if
diagnosis
.
Detail
!=
"
generic_un
schedulable"
{
t
.
Fatalf
(
"detail=%s want=
generic_un
schedulable"
,
diagnosis
.
Detail
)
}
}
func
TestDiagnoseSelectionFailure_
Sora
ModelRateLimitedDetail
(
t
*
testing
.
T
)
{
func
TestDiagnoseSelectionFailure_ModelRateLimitedDetail
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"
sora2-landscape-10s
"
model
:=
"
gpt-5.4
"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
UTC
()
.
Format
(
time
.
RFC3339
)
acc
:=
&
Account
{
ID
:
8
,
Platform
:
Platform
Sora
,
Platform
:
Platform
OpenAI
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
...
...
@@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
},
}
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
model
,
Platform
Sora
,
map
[
int64
]
struct
{}{},
false
)
diagnosis
:=
svc
.
diagnoseSelectionFailure
(
context
.
Background
(),
acc
,
model
,
Platform
OpenAI
,
map
[
int64
]
struct
{}{},
false
)
if
diagnosis
.
Category
!=
"model_rate_limited"
{
t
.
Fatalf
(
"category=%s want=model_rate_limited"
,
diagnosis
.
Category
)
}
...
...
backend/internal/service/gateway_service_sora_model_support_test.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
"testing"
func
TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected sora model to be supported when model_mapping is empty"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-4o"
:
"gpt-4o"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected sora model to be supported when mapping has no sora selectors"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"sora2"
:
"sora2"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-15s"
)
{
t
.
Fatalf
(
"expected family selector sora2 to support sora2-landscape-15s"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"sy_8"
:
"sy_8"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected underlying model selector sy_8 to support sora2-landscape-10s"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformSora
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-image"
:
"gpt-image"
,
},
},
}
if
svc
.
isModelSupportedByAccount
(
account
,
"sora2-landscape-10s"
)
{
t
.
Fatalf
(
"expected video model to be blocked when mapping explicitly only allows gpt-image"
)
}
}
backend/internal/service/gateway_service_sora_scheduling_test.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"testing"
"time"
)
func
TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
now
:=
time
.
Now
()
past
:=
now
.
Add
(
-
1
*
time
.
Minute
)
future
:=
now
.
Add
(
5
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
AutoPauseOnExpired
:
true
,
ExpiresAt
:
&
past
,
OverloadUntil
:
&
future
,
RateLimitResetAt
:
&
future
,
}
if
!
svc
.
isAccountSchedulableForSelection
(
acc
)
{
t
.
Fatalf
(
"expected sora account to ignore generic expiry/overload/rate-limit windows"
)
}
}
func
TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
future
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
future
,
}
if
svc
.
isAccountSchedulableForSelection
(
acc
)
{
t
.
Fatalf
(
"expected non-sora account to keep generic schedulable checks"
)
}
}
func
TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
model
:=
"sora2-landscape-10s"
resetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
.
UTC
()
.
Format
(
time
.
RFC3339
)
globalResetAt
:=
time
.
Now
()
.
Add
(
2
*
time
.
Minute
)
acc
:=
&
Account
{
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
globalResetAt
,
Extra
:
map
[
string
]
any
{
"model_rate_limits"
:
map
[
string
]
any
{
model
:
map
[
string
]
any
{
"rate_limit_reset_at"
:
resetAt
,
},
},
},
}
if
svc
.
isAccountSchedulableForModelSelection
(
context
.
Background
(),
acc
,
model
)
{
t
.
Fatalf
(
"expected sora account to be blocked by model scope rate limit"
)
}
}
func
TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
future
:=
time
.
Now
()
.
Add
(
3
*
time
.
Minute
)
accounts
:=
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformSora
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
&
future
,
},
}
stats
:=
svc
.
collectSelectionFailureStats
(
context
.
Background
(),
accounts
,
"sora2-landscape-10s"
,
PlatformSora
,
map
[
int64
]
struct
{}{},
false
)
if
stats
.
Unschedulable
!=
0
||
stats
.
Eligible
!=
1
{
t
.
Fatalf
(
"unexpected stats: unschedulable=%d eligible=%d"
,
stats
.
Unschedulable
,
stats
.
Eligible
)
}
}
backend/internal/service/group.go
View file @
eb2dce92
...
...
@@ -26,15 +26,6 @@ type Group struct {
ImagePrice2K
*
float64
ImagePrice4K
*
float64
// Sora 按次计费配置(阶段 1)
SoraImagePrice360
*
float64
SoraImagePrice540
*
float64
SoraVideoPricePerRequest
*
float64
SoraVideoPricePerRequestHD
*
float64
// Sora 存储配额
SoraStorageQuotaBytes
int64
// Claude Code 客户端限制
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
...
...
@@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
}
}
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
func
(
g
*
Group
)
GetSoraImagePrice
(
imageSize
string
)
*
float64
{
switch
imageSize
{
case
"360"
:
return
g
.
SoraImagePrice360
case
"540"
:
return
g
.
SoraImagePrice540
default
:
return
g
.
SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
...
...
backend/internal/service/openai_gateway_record_usage_test.go
View file @
eb2dce92
...
...
@@ -933,6 +933,89 @@ func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(
require
.
Equal
(
t
,
expectedCost
.
ActualCost
,
userRepo
.
lastAmount
)
}
func
TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingModelWhenUnmapped
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
usage
:=
OpenAIUsage
{
InputTokens
:
20
,
OutputTokens
:
10
}
// When channel did NOT map the model (ChannelMappedModel == OriginalModel),
// billing should use result.BillingModel (the actual model used after group
// DefaultMappedModel resolution), not the unmapped original model.
expectedCost
,
err
:=
svc
.
billingService
.
CalculateCost
(
"gpt-5.1"
,
UsageTokens
{
InputTokens
:
20
,
OutputTokens
:
10
,
},
1.1
)
require
.
NoError
(
t
,
err
)
err
=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_channel_unmapped_billing"
,
Model
:
"glm"
,
BillingModel
:
"gpt-5.1"
,
UpstreamModel
:
"gpt-5.1"
,
Usage
:
usage
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10
},
User
:
&
User
{
ID
:
20
},
Account
:
&
Account
{
ID
:
30
},
ChannelUsageFields
:
ChannelUsageFields
{
ChannelID
:
1
,
OriginalModel
:
"glm"
,
ChannelMappedModel
:
"glm"
,
// channel did NOT map
BillingModelSource
:
BillingModelSourceChannelMapped
,
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
expectedCost
.
ActualCost
,
usageRepo
.
lastLog
.
ActualCost
)
require
.
True
(
t
,
usageRepo
.
lastLog
.
ActualCost
>
0
,
"cost must not be zero"
)
}
func
TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenMapped
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
usage
:=
OpenAIUsage
{
InputTokens
:
20
,
OutputTokens
:
10
}
// When channel DID map the model (ChannelMappedModel != OriginalModel),
// billing should use the channel-mapped model, honoring admin intent.
expectedCost
,
err
:=
svc
.
billingService
.
CalculateCost
(
"gpt-5.1"
,
UsageTokens
{
InputTokens
:
20
,
OutputTokens
:
10
,
},
1.1
)
require
.
NoError
(
t
,
err
)
err
=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_channel_mapped_billing"
,
Model
:
"glm"
,
BillingModel
:
"gpt-5.1-codex"
,
UpstreamModel
:
"gpt-5.1-codex"
,
Usage
:
usage
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10
},
User
:
&
User
{
ID
:
20
},
Account
:
&
Account
{
ID
:
30
},
ChannelUsageFields
:
ChannelUsageFields
{
ChannelID
:
1
,
OriginalModel
:
"glm"
,
ChannelMappedModel
:
"gpt-5.1"
,
// channel mapped glm → gpt-5.1
BillingModelSource
:
BillingModelSourceChannelMapped
,
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
expectedCost
.
ActualCost
,
usageRepo
.
lastLog
.
ActualCost
)
require
.
True
(
t
,
usageRepo
.
lastLog
.
ActualCost
>
0
,
"cost must not be zero"
)
}
func
TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
...
...
backend/internal/service/openai_gateway_service.go
View file @
eb2dce92
...
...
@@ -4277,7 +4277,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if
result
.
BillingModel
!=
""
{
billingModel
=
strings
.
TrimSpace
(
result
.
BillingModel
)
}
if
input
.
BillingModelSource
==
BillingModelSourceChannelMapped
&&
input
.
ChannelMappedModel
!=
""
{
if
input
.
BillingModelSource
==
BillingModelSourceChannelMapped
&&
input
.
ChannelMappedModel
!=
""
&&
input
.
ChannelMappedModel
!=
input
.
OriginalModel
{
billingModel
=
input
.
ChannelMappedModel
}
if
input
.
BillingModelSource
==
BillingModelSourceRequested
&&
input
.
OriginalModel
!=
""
{
...
...
backend/internal/service/openai_oauth_service.go
View file @
eb2dce92
...
...
@@ -3,30 +3,15 @@ package service
import
(
"context"
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
var
openAISoraSessionAuthURL
=
"https://sora.chatgpt.com/api/auth/session"
var
soraSessionCookiePattern
=
regexp
.
MustCompile
(
`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`
)
type
soraSessionChunk
struct
{
index
int
value
string
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type
OpenAIOAuthService
struct
{
sessionStore
*
openai
.
SessionStore
...
...
@@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return
s
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
""
)
}
// RefreshTokenWithClientID refreshes an OpenAI
/Sora
OAuth token with optional client_id.
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func
(
s
*
OpenAIOAuthService
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
string
,
proxyURL
string
,
clientID
string
)
(
*
OpenAITokenInfo
,
error
)
{
tokenResp
,
err
:=
s
.
oauthClient
.
RefreshTokenWithClientID
(
ctx
,
refreshToken
,
proxyURL
,
clientID
)
if
err
!=
nil
{
...
...
@@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo
.
PrivacyMode
=
disableOpenAITraining
(
ctx
,
s
.
privacyClientFactory
,
tokenInfo
.
AccessToken
,
proxyURL
)
}
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func
(
s
*
OpenAIOAuthService
)
ExchangeSoraSessionToken
(
ctx
context
.
Context
,
sessionToken
string
,
proxyID
*
int64
)
(
*
OpenAITokenInfo
,
error
)
{
sessionToken
=
normalizeSoraSessionTokenInput
(
sessionToken
)
if
strings
.
TrimSpace
(
sessionToken
)
==
""
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"SORA_SESSION_TOKEN_REQUIRED"
,
"session_token is required"
)
}
proxyURL
,
err
:=
s
.
resolveProxyURL
(
ctx
,
proxyID
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
openAISoraSessionAuthURL
,
nil
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusInternalServerError
,
"SORA_SESSION_REQUEST_BUILD_FAILED"
,
"failed to build request: %v"
,
err
)
}
req
.
Header
.
Set
(
"Cookie"
,
"__Secure-next-auth.session-token="
+
strings
.
TrimSpace
(
sessionToken
))
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
req
.
Header
.
Set
(
"Origin"
,
"https://sora.chatgpt.com"
)
req
.
Header
.
Set
(
"Referer"
,
"https://sora.chatgpt.com/"
)
req
.
Header
.
Set
(
"User-Agent"
,
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
120
*
time
.
Second
,
})
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_CLIENT_FAILED"
,
"create http client failed: %v"
,
err
)
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_REQUEST_FAILED"
,
"request failed: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_EXCHANGE_FAILED"
,
"status %d: %s"
,
resp
.
StatusCode
,
strings
.
TrimSpace
(
string
(
body
)))
}
var
sessionResp
struct
{
AccessToken
string
`json:"accessToken"`
Expires
string
`json:"expires"`
User
struct
{
Email
string
`json:"email"`
Name
string
`json:"name"`
}
`json:"user"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
sessionResp
);
err
!=
nil
{
return
nil
,
infraerrors
.
Newf
(
http
.
StatusBadGateway
,
"SORA_SESSION_PARSE_FAILED"
,
"failed to parse response: %v"
,
err
)
}
if
strings
.
TrimSpace
(
sessionResp
.
AccessToken
)
==
""
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadGateway
,
"SORA_SESSION_ACCESS_TOKEN_MISSING"
,
"session exchange response missing access token"
)
}
expiresAt
:=
time
.
Now
()
.
Add
(
time
.
Hour
)
.
Unix
()
if
strings
.
TrimSpace
(
sessionResp
.
Expires
)
!=
""
{
if
parsed
,
parseErr
:=
time
.
Parse
(
time
.
RFC3339
,
sessionResp
.
Expires
);
parseErr
==
nil
{
expiresAt
=
parsed
.
Unix
()
}
}
expiresIn
:=
expiresAt
-
time
.
Now
()
.
Unix
()
if
expiresIn
<
0
{
expiresIn
=
0
}
return
&
OpenAITokenInfo
{
AccessToken
:
strings
.
TrimSpace
(
sessionResp
.
AccessToken
),
ExpiresIn
:
expiresIn
,
ExpiresAt
:
expiresAt
,
ClientID
:
openai
.
SoraClientID
,
Email
:
strings
.
TrimSpace
(
sessionResp
.
User
.
Email
),
},
nil
}
func
normalizeSoraSessionTokenInput
(
raw
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
""
}
matches
:=
soraSessionCookiePattern
.
FindAllStringSubmatch
(
trimmed
,
-
1
)
if
len
(
matches
)
==
0
{
return
sanitizeSessionToken
(
trimmed
)
}
chunkMatches
:=
make
([]
soraSessionChunk
,
0
,
len
(
matches
))
singleValues
:=
make
([]
string
,
0
,
len
(
matches
))
for
_
,
match
:=
range
matches
{
if
len
(
match
)
<
3
{
continue
}
value
:=
sanitizeSessionToken
(
match
[
2
])
if
value
==
""
{
continue
}
if
strings
.
TrimSpace
(
match
[
1
])
==
""
{
singleValues
=
append
(
singleValues
,
value
)
continue
}
idx
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
match
[
1
]))
if
err
!=
nil
||
idx
<
0
{
continue
}
chunkMatches
=
append
(
chunkMatches
,
soraSessionChunk
{
index
:
idx
,
value
:
value
,
})
}
if
merged
:=
mergeLatestSoraSessionChunks
(
chunkMatches
);
merged
!=
""
{
return
merged
}
if
len
(
singleValues
)
>
0
{
return
singleValues
[
len
(
singleValues
)
-
1
]
}
return
""
}
func
mergeSoraSessionChunkSegment
(
chunks
[]
soraSessionChunk
,
requiredMaxIndex
int
,
requireComplete
bool
)
string
{
if
len
(
chunks
)
==
0
{
return
""
}
byIndex
:=
make
(
map
[
int
]
string
,
len
(
chunks
))
for
_
,
chunk
:=
range
chunks
{
byIndex
[
chunk
.
index
]
=
chunk
.
value
}
if
_
,
ok
:=
byIndex
[
0
];
!
ok
{
return
""
}
if
requireComplete
{
for
idx
:=
0
;
idx
<=
requiredMaxIndex
;
idx
++
{
if
_
,
ok
:=
byIndex
[
idx
];
!
ok
{
return
""
}
}
}
orderedIndexes
:=
make
([]
int
,
0
,
len
(
byIndex
))
for
idx
:=
range
byIndex
{
orderedIndexes
=
append
(
orderedIndexes
,
idx
)
}
sort
.
Ints
(
orderedIndexes
)
var
builder
strings
.
Builder
for
_
,
idx
:=
range
orderedIndexes
{
if
_
,
err
:=
builder
.
WriteString
(
byIndex
[
idx
]);
err
!=
nil
{
return
""
}
}
return
sanitizeSessionToken
(
builder
.
String
())
}
func
mergeLatestSoraSessionChunks
(
chunks
[]
soraSessionChunk
)
string
{
if
len
(
chunks
)
==
0
{
return
""
}
requiredMaxIndex
:=
0
for
_
,
chunk
:=
range
chunks
{
if
chunk
.
index
>
requiredMaxIndex
{
requiredMaxIndex
=
chunk
.
index
}
}
groupStarts
:=
make
([]
int
,
0
,
len
(
chunks
))
for
idx
,
chunk
:=
range
chunks
{
if
chunk
.
index
==
0
{
groupStarts
=
append
(
groupStarts
,
idx
)
}
}
if
len
(
groupStarts
)
==
0
{
return
mergeSoraSessionChunkSegment
(
chunks
,
requiredMaxIndex
,
false
)
}
for
i
:=
len
(
groupStarts
)
-
1
;
i
>=
0
;
i
--
{
start
:=
groupStarts
[
i
]
end
:=
len
(
chunks
)
if
i
+
1
<
len
(
groupStarts
)
{
end
=
groupStarts
[
i
+
1
]
}
if
merged
:=
mergeSoraSessionChunkSegment
(
chunks
[
start
:
end
],
requiredMaxIndex
,
true
);
merged
!=
""
{
return
merged
}
}
return
mergeSoraSessionChunkSegment
(
chunks
,
requiredMaxIndex
,
false
)
}
func
sanitizeSessionToken
(
raw
string
)
string
{
token
:=
strings
.
TrimSpace
(
raw
)
token
=
strings
.
Trim
(
token
,
"
\"
'`"
)
token
=
strings
.
TrimSuffix
(
token
,
";"
)
return
strings
.
TrimSpace
(
token
)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
// RefreshAccountToken refreshes token for an OpenAI OAuth account
func
(
s
*
OpenAIOAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
if
account
.
Platform
!=
PlatformOpenAI
&&
account
.
Platform
!=
PlatformSora
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT"
,
"account is not an OpenAI
/Sora
account"
)
if
account
.
Platform
!=
PlatformOpenAI
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT"
,
"account is not an OpenAI account"
)
}
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_INVALID_ACCOUNT_TYPE"
,
"account is not an OAuth account"
)
...
...
@@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() {
s
.
sessionStore
.
Stop
()
}
func
(
s
*
OpenAIOAuthService
)
resolveProxyURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
string
,
error
)
{
if
proxyID
==
nil
{
return
""
,
nil
}
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
if
err
!=
nil
{
return
""
,
infraerrors
.
Newf
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_PROXY_NOT_FOUND"
,
"proxy not found: %v"
,
err
)
}
if
proxy
==
nil
{
return
""
,
nil
}
return
proxy
.
URL
(),
nil
}
func
normalizeOpenAIOAuthPlatform
(
platform
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
platform
))
{
case
PlatformSora
:
return
openai
.
OAuthPlatformSora
default
:
return
openai
.
OAuthPlatformOpenAI
}
return
openai
.
OAuthPlatformOpenAI
}
backend/internal/service/openai_oauth_service_auth_url_test.go
View file @
eb2dce92
...
...
@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
func
TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient
(
t
*
testing
.
T
)
{
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientAuthURLStub
{})
defer
svc
.
Stop
()
result
,
err
:=
svc
.
GenerateAuthURL
(
context
.
Background
(),
nil
,
""
,
PlatformSora
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
result
.
AuthURL
)
require
.
NotEmpty
(
t
,
result
.
SessionID
)
parsed
,
err
:=
url
.
Parse
(
result
.
AuthURL
)
require
.
NoError
(
t
,
err
)
q
:=
parsed
.
Query
()
require
.
Equal
(
t
,
openai
.
ClientID
,
q
.
Get
(
"client_id"
))
require
.
Empty
(
t
,
q
.
Get
(
"codex_cli_simplified_flow"
))
session
,
ok
:=
svc
.
sessionStore
.
Get
(
result
.
SessionID
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
openai
.
ClientID
,
session
.
ClientID
)
}
backend/internal/service/openai_oauth_service_sora_session_test.go
deleted
100644 → 0
View file @
7b83d6e7
package
service
import
(
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type
openaiOAuthClientNoopStub
struct
{}
func
(
s
*
openaiOAuthClientNoopStub
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientNoopStub
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_Success
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-token"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
info
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
require
.
Equal
(
t
,
"demo@example.com"
,
info
.
Email
)
require
.
Greater
(
t
,
info
.
ExpiresAt
,
int64
(
0
))
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"expires":"2099-01-01T00:00:00Z"}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
_
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
"st-token"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"missing access token"
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=st-cookie-value"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
"__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=chunk-0chunk-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=new-0new-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly"
,
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
func
TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup
(
t
*
testing
.
T
)
{
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
http
.
MethodGet
,
r
.
Method
)
require
.
Contains
(
t
,
r
.
Header
.
Get
(
"Cookie"
),
"__Secure-next-auth.session-token=ok-0ok-1"
)
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`
))
}))
defer
server
.
Close
()
origin
:=
openAISoraSessionAuthURL
openAISoraSessionAuthURL
=
server
.
URL
defer
func
()
{
openAISoraSessionAuthURL
=
origin
}()
svc
:=
NewOpenAIOAuthService
(
nil
,
&
openaiOAuthClientNoopStub
{})
defer
svc
.
Stop
()
raw
:=
strings
.
Join
([]
string
{
"set-cookie"
,
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/"
,
"set-cookie"
,
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/"
,
},
"
\n
"
)
info
,
err
:=
svc
.
ExchangeSoraSessionToken
(
context
.
Background
(),
raw
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"at-token"
,
info
.
AccessToken
)
}
backend/internal/service/openai_token_provider.go
View file @
eb2dce92
...
...
@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider manages access_token for OpenAI
/Sora
OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
...
...
@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
(
account
.
Platform
!=
PlatformOpenAI
&&
account
.
Platform
!=
PlatformSora
)
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai
/sora
oauth account"
)
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
...
...
@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p
.
metrics
.
refreshRequests
.
Add
(
1
)
p
.
metrics
.
touchNow
()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if
account
.
Platform
==
PlatformSora
{
slog
.
Debug
(
"openai_token_refresh_skipped_for_sora"
,
"account_id"
,
account
.
ID
)
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment