Commit eb2dce92 authored by 陈曦's avatar 陈曦
Browse files

升级v1.0.8 解决冲突

parents 7b83d6e7 339d906e
......@@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
require.Contains(t, result.Summary, "挑战 1 项")
}
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
......@@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
defer server.Close()
target := proxyQualityTarget{
Target: "sora",
Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
......
......@@ -5,13 +5,12 @@ package service
import (
"bytes"
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
"io"
"net/http"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
......@@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
m.responseBodies[respIdx] = bodyBytes
}
// 用缓存的 body 字节重建新的 reader
var body io.ReadCloser
// 用缓存的 body 重建 reader(支持重试场景多次读取)
cloned := *resp
if m.responseBodies[respIdx] != nil {
body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
}
return &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: body,
}, respErr
return &cloned, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
......
......@@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
......
......@@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
......@@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
......
......@@ -808,14 +808,6 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
}
// SoraPriceConfig Sora 按次计费配置
type SoraPriceConfig struct {
ImagePrice360 *float64
ImagePrice540 *float64
VideoPricePerRequest *float64
VideoPricePerRequestHD *float64
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
......@@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
// CalculateSoraImageCost 计算 Sora 图片按次费用
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
unitPrice := 0.0
if groupConfig != nil {
switch imageSize {
case "540":
if groupConfig.ImagePrice540 != nil {
unitPrice = *groupConfig.ImagePrice540
}
default:
if groupConfig.ImagePrice360 != nil {
unitPrice = *groupConfig.ImagePrice360
}
}
}
totalCost := unitPrice * float64(imageCount)
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
unitPrice := 0.0
if groupConfig != nil {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
if groupConfig.VideoPricePerRequestHD != nil {
unitPrice = *groupConfig.VideoPricePerRequestHD
}
}
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
unitPrice = *groupConfig.VideoPricePerRequest
}
}
totalCost := unitPrice
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
......
......@@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
func TestCalculateSoraVideoCost(t *testing.T) {
svc := newTestBillingService()
price := 0.5
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
}
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
svc := newTestBillingService()
hdPrice := 1.0
normalPrice := 0.5
cfg := &SoraPriceConfig{
VideoPricePerRequest: &normalPrice,
VideoPricePerRequestHD: &hdPrice,
}
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
......@@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
require.Contains(t, err.Error(), "not initialized")
}
func TestCalculateSoraImageCost(t *testing.T) {
svc := newTestBillingService()
price360 := 0.05
price540 := 0.08
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
}
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
// 使用空的 fallback prices 让 GetModelPricing 失败
svc := &BillingService{
......
......@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
......@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
......@@ -311,23 +308,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
if groupPlatform == pricingPlatform {
return true
}
if groupPlatform == PlatformAntigravity {
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
}
return false
return groupPlatform == pricingPlatform
}
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
// 各平台严格独立,只返回自身。
func matchingPlatforms(groupPlatform string) []string {
if groupPlatform == PlatformAntigravity {
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
}
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
......@@ -364,10 +352,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
......@@ -384,7 +370,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台查找模型映射。
// lookupMappingAcrossPlatforms 在分组平台查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
......@@ -442,8 +428,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
// 确保跨平台同名模型各自独立匹配。
// 各平台严格独立,只在本平台内查找定价。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
......@@ -524,7 +509,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
// antigravity 分组依次尝试所有匹配平台的定价列表。
// 只在本平台的定价列表中查找
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
......
......@@ -1932,8 +1932,8 @@ func TestIsPlatformPricingMatch(t *testing.T) {
pricingPlatform string
want bool
}{
{"antigravity matches anthropic", PlatformAntigravity, PlatformAnthropic, true},
{"antigravity matches gemini", PlatformAntigravity, PlatformGemini, true},
{"antigravity does NOT match anthropic", PlatformAntigravity, PlatformAnthropic, false},
{"antigravity does NOT match gemini", PlatformAntigravity, PlatformGemini, false},
{"antigravity matches antigravity", PlatformAntigravity, PlatformAntigravity, true},
{"antigravity does NOT match openai", PlatformAntigravity, PlatformOpenAI, false},
{"anthropic matches anthropic", PlatformAnthropic, PlatformAnthropic, true},
......@@ -1963,7 +1963,7 @@ func TestMatchingPlatforms(t *testing.T) {
groupPlatform string
want []string
}{
{"antigravity returns all three", PlatformAntigravity, []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}},
{"antigravity returns itself only", PlatformAntigravity, []string{PlatformAntigravity}},
{"anthropic returns itself", PlatformAnthropic, []string{PlatformAnthropic}},
{"gemini returns itself", PlatformGemini, []string{PlatformGemini}},
{"openai returns itself", PlatformOpenAI, []string{PlatformOpenAI}},
......@@ -1978,12 +1978,12 @@ func TestMatchingPlatforms(t *testing.T) {
}
// ===========================================================================
// 9. Antigravity cross-platform channel pricing
// 9. Antigravity platform isolation — no cross-platform pricing leakage
// ===========================================================================
func TestGetChannelModelPricing_AntigravityCrossPlatform(t *testing.T) {
func TestGetChannelModelPricing_AntigravityDoesNotSeeCrossPlatformPricing(t *testing.T) {
// Channel has anthropic pricing for claude-opus-4-6.
// Group 10 is antigravity — should see the anthropic pricing.
// Group 10 is antigravity — should NOT see the anthropic pricing.
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -1996,9 +1996,7 @@ func TestGetChannelModelPricing_AntigravityCrossPlatform(t *testing.T) {
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6")
require.NotNil(t, result, "antigravity group should see anthropic pricing")
require.Equal(t, int64(100), result.ID)
require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
require.Nil(t, result, "antigravity group should NOT see anthropic-platform pricing")
}
func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) {
......@@ -2020,12 +2018,12 @@ func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.
}
// ===========================================================================
// 10. Antigravity cross-platform model mapping
// 10. Antigravity platform isolation — no cross-platform model mapping
// ===========================================================================
func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
func TestResolveChannelMapping_AntigravityDoesNotSeeCrossPlatformMapping(t *testing.T) {
// Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6.
// Group 10 is antigravity — should apply the anthropic mapping.
// Group 10 is antigravity — should NOT apply the anthropic mapping.
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2040,18 +2038,17 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4-5")
require.True(t, result.Mapped, "antigravity group should apply anthropic mapping")
require.Equal(t, "claude-opus-4-6", result.MappedModel)
require.Equal(t, int64(1), result.ChannelID)
require.False(t, result.Mapped, "antigravity group should NOT apply anthropic mapping")
require.Equal(t, "claude-opus-4-5", result.MappedModel)
}
// ===========================================================================
// 11. Antigravity cross-platform same-name model — no overwrite
// 11. Antigravity platform isolation — same-name model across platforms
// ===========================================================================
func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testing.T) {
func TestGetChannelModelPricing_AntigravityDoesNotSeeSameModelFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
// antigravity 分组应能分别查到各自的定价,而不是后者覆盖前者
// antigravity 分组不应看到任何一个(各平台严格独立)
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2064,17 +2061,13 @@ func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testin
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
// antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.NotNil(t, result, "antigravity group should find pricing for shared-model")
// 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini])
require.Equal(t, int64(200), result.ID)
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
require.Nil(t, result, "antigravity group should NOT see anthropic/gemini-platform pricing")
}
func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) {
func TestGetChannelModelPricing_AntigravityDoesNotSeeGeminiOnlyPricing(t *testing.T) {
// 只有 gemini 平台定义了模型 "gemini-model"。
// antigravity 分组应能查到 gemini 的定价。
// antigravity 分组不应看到 gemini 的定价。
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2087,14 +2080,12 @@ func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) {
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model")
require.NotNil(t, result, "antigravity group should find gemini pricing")
require.Equal(t, int64(300), result.ID)
require.InDelta(t, 2e-6, *result.InputPrice, 1e-12)
require.Nil(t, result, "antigravity group should NOT see gemini-platform pricing")
}
func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *testing.T) {
// anthropic 和 gemini 都有 "shared-*" 通配符定价,价格不同
// antigravity 分组查找 "shared-model" 应命中第一个匹配而非被覆盖
func TestGetChannelModelPricing_AntigravityDoesNotSeeWildcardFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都有 "shared-*" 通配符定价。
// antigravity 分组不应命中任何一个
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2108,15 +2099,12 @@ func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *t
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.NotNil(t, result, "antigravity group should find wildcard pricing for shared-model")
// 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序)
require.Equal(t, int64(400), result.ID)
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
require.Nil(t, result, "antigravity group should NOT see wildcard pricing from other platforms")
}
func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing.T) {
func TestResolveChannelMapping_AntigravityDoesNotSeeMappingFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
// antigravity 分组应命中 anthropic 的映射(按 matchingPlatforms 顺序)
// antigravity 分组应命中任何一个
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2130,13 +2118,13 @@ func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing
svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "alias")
require.True(t, result.Mapped)
require.Equal(t, "anthropic-target", result.MappedModel)
require.False(t, result.Mapped, "antigravity group should NOT see mapping from other platforms")
require.Equal(t, "alias", result.MappedModel)
}
func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) {
func TestCheckRestricted_AntigravityDoesNotSeeModelsFromOtherPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型 "shared-model"。
// antigravity 分组启用了 RestrictModels,"shared-model" 应被限制。
// antigravity 分组启用了 RestrictModels,"shared-model" 应被限制(各平台独立)
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -2151,13 +2139,39 @@ func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) {
svc := newTestChannelService(repo)
restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model")
require.False(t, restricted, "shared-model should not be restricted for antigravity")
require.True(t, restricted, "shared-model from other platforms should be restricted for antigravity")
// 未定义的模型应被限制
restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model")
require.True(t, restricted, "unknown-model should be restricted for antigravity")
}
func TestGetChannelModelPricing_AntigravityOwnPricingWorks(t *testing.T) {
// antigravity 平台自己配置的定价应正常生效(覆盖 Claude 和 Gemini 模型)。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 600, Platform: PlatformAntigravity, Models: []string{"claude-*"}, InputPrice: testPtrFloat64(15e-6)},
{ID: 601, Platform: PlatformAntigravity, Models: []string{"gemini-*"}, InputPrice: testPtrFloat64(2e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
// Claude 模型匹配 antigravity 定价
result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")
require.NotNil(t, result)
require.Equal(t, int64(600), result.ID)
require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
// Gemini 模型匹配 antigravity 定价
result = svc.GetChannelModelPricing(context.Background(), 10, "gemini-2.5-flash")
require.NotNil(t, result)
require.Equal(t, int64(601), result.ID)
require.InDelta(t, 2e-6, *result.InputPrice, 1e-12)
}
func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
// 确保非 antigravity 平台的行为不受影响。
// anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。
......
......@@ -24,7 +24,6 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
PlatformSora = domain.PlatformSora
)
// Account type constants
......@@ -107,7 +106,6 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
......@@ -199,27 +197,6 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
// =========================
// Claude Code Version Check
// =========================
......
......@@ -60,13 +60,6 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
......@@ -511,9 +504,6 @@ type ForwardResult struct {
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
......@@ -1971,9 +1961,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if platform == PlatformSora {
return s.listSoraSchedulableAccounts(ctx, groupID)
}
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
......@@ -2070,53 +2057,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
const useMixed = false
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
} else if groupID != nil {
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform != PlatformSora {
continue
}
if !s.isSoraAccountSchedulable(&acc) {
continue
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_sora",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
......@@ -2141,33 +2081,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
return s.soraUnschedulableReason(account) == ""
}
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
if account == nil {
return "account_nil"
}
if account.Status != StatusActive {
return fmt.Sprintf("status=%s", account.Status)
}
if !account.Schedulable {
return "schedulable=false"
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
}
return ""
}
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
if account.Platform == PlatformSora {
return s.isSoraAccountSchedulable(account)
}
return account.IsSchedulable()
}
......@@ -2175,12 +2092,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
if account.Platform == PlatformSora {
if !s.isSoraAccountSchedulable(account) {
return false
}
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
}
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
......@@ -3357,9 +3268,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
if platform == PlatformSora {
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
}
return stats
}
......@@ -3416,11 +3324,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"}
}
if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable"
if acc.Platform == PlatformSora {
detail = s.soraUnschedulableReason(acc)
}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{
......@@ -3444,57 +3348,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
func (s *GatewayService) logSoraSelectionFailureDetails(
ctx context.Context,
groupID *int64,
sessionHash string,
requestedModel string,
accounts []Account,
excludedIDs map[int64]struct{},
allowMixedScheduling bool,
) {
const maxLines = 30
logged := 0
for i := range accounts {
if logged >= maxLines {
break
}
acc := &accounts[i]
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
if diagnosis.Category == "eligible" {
continue
}
detail := diagnosis.Detail
if detail == "" {
detail = "-"
}
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
acc.ID,
acc.Platform,
diagnosis.Category,
detail,
)
logged++
}
if len(accounts) > maxLines {
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
len(accounts),
logged,
)
}
}
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
......@@ -3573,9 +3426,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
if account.Platform == PlatformSora {
return s.isSoraModelSupportedByAccount(account, requestedModel)
}
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
......@@ -3588,143 +3438,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
if account == nil {
return false
}
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 先走原始精确/通配符匹配。
mapping := account.GetModelMapping()
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
return true
}
aliases := buildSoraModelAliases(requestedModel)
if len(aliases) == 0 {
return false
}
hasSoraSelector := false
for pattern := range mapping {
if !isSoraModelSelector(pattern) {
continue
}
hasSoraSelector = true
if matchPatternAnyAlias(pattern, aliases) {
return true
}
}
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
// 此时不应误拦截 Sora 模型请求。
if !hasSoraSelector {
return true
}
return false
}
func matchPatternAnyAlias(pattern string, aliases []string) bool {
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
if normalizedPattern == "" {
return false
}
for _, alias := range aliases {
if matchWildcard(normalizedPattern, alias) {
return true
}
}
return false
}
func isSoraModelSelector(pattern string) bool {
p := strings.ToLower(strings.TrimSpace(pattern))
if p == "" {
return false
}
switch {
case strings.HasPrefix(p, "sora"),
strings.HasPrefix(p, "gpt-image"),
strings.HasPrefix(p, "prompt-enhance"),
strings.HasPrefix(p, "sy_"):
return true
}
return p == "video" || p == "image"
}
func buildSoraModelAliases(requestedModel string) []string {
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
if modelID == "" {
return nil
}
aliases := make([]string, 0, 8)
addAlias := func(value string) {
v := strings.ToLower(strings.TrimSpace(value))
if v == "" {
return
}
for _, existing := range aliases {
if existing == v {
return
}
}
aliases = append(aliases, v)
}
addAlias(modelID)
cfg, ok := GetSoraModelConfig(modelID)
if ok {
addAlias(cfg.Model)
switch cfg.Type {
case "video":
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case "image":
addAlias("image")
addAlias("gpt-image")
case "prompt_enhance":
addAlias("prompt-enhance")
}
return aliases
}
switch {
case strings.HasPrefix(modelID, "sora"):
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case strings.HasPrefix(modelID, "gpt-image"):
addAlias("image")
addAlias("gpt-image")
case strings.HasPrefix(modelID, "prompt-enhance"):
addAlias("prompt-enhance")
default:
return nil
}
return aliases
}
func soraVideoFamilyAlias(modelID string) string {
switch {
case strings.HasPrefix(modelID, "sora2pro-hd"):
return "sora2pro-hd"
case strings.HasPrefix(modelID, "sora2pro"):
return "sora2pro"
case strings.HasPrefix(modelID, "sora2"):
return "sora2"
default:
return ""
}
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
......@@ -7592,9 +7305,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
......@@ -7750,8 +7460,6 @@ type recordUsageOpts struct {
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
......@@ -7842,7 +7550,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
......@@ -7944,16 +7651,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
......@@ -7963,28 +7660,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
......@@ -8133,13 +7808,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
BillingMode: resolveBillingMode(result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
......@@ -8163,13 +7837,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if isSoraMedia {
return nil
}
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
var mode string
switch {
case cost != nil && cost.BillingMode != "":
......@@ -8182,13 +7850,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return &mode
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
......
......@@ -9,35 +9,35 @@ import (
func TestCollectSelectionFailureStats(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
accounts := []Account{
// excluded
{
ID: 1,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
// unschedulable
{
ID: 2,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
},
// platform filtered
{
ID: 3,
Platform: PlatformOpenAI,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
},
// model unsupported
{
ID: 4,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
......@@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// model rate limited
{
ID: 5,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
......@@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// eligible
{
ID: 6,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
}
excluded := map[int64]struct{}{1: {}}
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false)
if stats.Total != 6 {
t.Fatalf("total=%d want=6", stats.Total)
......@@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
}
}
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) {
svc := &GatewayService{}
acc := &Account{
ID: 7,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
}
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "unschedulable" {
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
}
if diagnosis.Detail != "schedulable=false" {
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
if diagnosis.Detail != "generic_unschedulable" {
t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail)
}
}
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
acc := &Account{
ID: 8,
Platform: PlatformSora,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
......@@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
},
}
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "model_rate_limited" {
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
}
......
package service
import "testing"
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected sora model to be supported when model_mapping is empty")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-4o": "gpt-4o",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"sora2": "sora2",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"sy_8": "sy_8",
},
},
}
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
}
}
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformSora,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-image": "gpt-image",
},
},
}
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
}
}
package service
import (
"context"
"testing"
"time"
)
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
svc := &GatewayService{}
now := time.Now()
past := now.Add(-1 * time.Minute)
future := now.Add(5 * time.Minute)
acc := &Account{
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
AutoPauseOnExpired: true,
ExpiresAt: &past,
OverloadUntil: &future,
RateLimitResetAt: &future,
}
if !svc.isAccountSchedulableForSelection(acc) {
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
}
}
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
svc := &GatewayService{}
future := time.Now().Add(5 * time.Minute)
acc := &Account{
Platform: PlatformAnthropic,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &future,
}
if svc.isAccountSchedulableForSelection(acc) {
t.Fatalf("expected non-sora account to keep generic schedulable checks")
}
}
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
svc := &GatewayService{}
model := "sora2-landscape-10s"
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
globalResetAt := time.Now().Add(2 * time.Minute)
acc := &Account{
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &globalResetAt,
Extra: map[string]any{
"model_rate_limits": map[string]any{
model: map[string]any{
"rate_limit_reset_at": resetAt,
},
},
},
}
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
t.Fatalf("expected sora account to be blocked by model scope rate limit")
}
}
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
svc := &GatewayService{}
future := time.Now().Add(3 * time.Minute)
accounts := []Account{
{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &future,
},
}
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
if stats.Unschedulable != 0 || stats.Eligible != 1 {
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
}
}
......@@ -26,15 +26,6 @@ type Group struct {
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置(阶段 1)
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Sora 存储配额
SoraStorageQuotaBytes int64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
......@@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
}
}
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
switch imageSize {
case "360":
return g.SoraImagePrice360
case "540":
return g.SoraImagePrice540
default:
return g.SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
if group == nil {
......
......@@ -933,6 +933,89 @@ func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
}
func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingModelWhenUnmapped(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
// When channel did NOT map the model (ChannelMappedModel == OriginalModel),
// billing should use result.BillingModel (the actual model used after group
// DefaultMappedModel resolution), not the unmapped original model.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_channel_unmapped_billing",
Model: "glm",
BillingModel: "gpt-5.1",
UpstreamModel: "gpt-5.1",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
ChannelUsageFields: ChannelUsageFields{
ChannelID: 1,
OriginalModel: "glm",
ChannelMappedModel: "glm", // channel did NOT map
BillingModelSource: BillingModelSourceChannelMapped,
},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenMapped(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
// When channel DID map the model (ChannelMappedModel != OriginalModel),
// billing should use the channel-mapped model, honoring admin intent.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_channel_mapped_billing",
Model: "glm",
BillingModel: "gpt-5.1-codex",
UpstreamModel: "gpt-5.1-codex",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
ChannelUsageFields: ChannelUsageFields{
ChannelID: 1,
OriginalModel: "glm",
ChannelMappedModel: "gpt-5.1", // channel mapped glm → gpt-5.1
BillingModelSource: BillingModelSourceChannelMapped,
},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -4277,7 +4277,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" && input.ChannelMappedModel != input.OriginalModel {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
......
......@@ -3,30 +3,15 @@ package service
import (
"context"
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
type soraSessionChunk struct {
index int
value string
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
......@@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil {
......@@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if resp.StatusCode != http.StatusOK {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var sessionResp struct {
AccessToken string `json:"accessToken"`
Expires string `json:"expires"`
User struct {
Email string `json:"email"`
Name string `json:"name"`
} `json:"user"`
}
if err := json.Unmarshal(body, &sessionResp); err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
}
if strings.TrimSpace(sessionResp.AccessToken) == "" {
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
}
expiresAt := time.Now().Add(time.Hour).Unix()
if strings.TrimSpace(sessionResp.Expires) != "" {
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
expiresAt = parsed.Unix()
}
}
expiresIn := expiresAt - time.Now().Unix()
if expiresIn < 0 {
expiresIn = 0
}
return &OpenAITokenInfo{
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
ClientID: openai.SoraClientID,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
func normalizeSoraSessionTokenInput(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return ""
}
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
if len(matches) == 0 {
return sanitizeSessionToken(trimmed)
}
chunkMatches := make([]soraSessionChunk, 0, len(matches))
singleValues := make([]string, 0, len(matches))
for _, match := range matches {
if len(match) < 3 {
continue
}
value := sanitizeSessionToken(match[2])
if value == "" {
continue
}
if strings.TrimSpace(match[1]) == "" {
singleValues = append(singleValues, value)
continue
}
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
if err != nil || idx < 0 {
continue
}
chunkMatches = append(chunkMatches, soraSessionChunk{
index: idx,
value: value,
})
}
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
return merged
}
if len(singleValues) > 0 {
return singleValues[len(singleValues)-1]
}
return ""
}
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
if len(chunks) == 0 {
return ""
}
byIndex := make(map[int]string, len(chunks))
for _, chunk := range chunks {
byIndex[chunk.index] = chunk.value
}
if _, ok := byIndex[0]; !ok {
return ""
}
if requireComplete {
for idx := 0; idx <= requiredMaxIndex; idx++ {
if _, ok := byIndex[idx]; !ok {
return ""
}
}
}
orderedIndexes := make([]int, 0, len(byIndex))
for idx := range byIndex {
orderedIndexes = append(orderedIndexes, idx)
}
sort.Ints(orderedIndexes)
var builder strings.Builder
for _, idx := range orderedIndexes {
if _, err := builder.WriteString(byIndex[idx]); err != nil {
return ""
}
}
return sanitizeSessionToken(builder.String())
}
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
if len(chunks) == 0 {
return ""
}
requiredMaxIndex := 0
for _, chunk := range chunks {
if chunk.index > requiredMaxIndex {
requiredMaxIndex = chunk.index
}
}
groupStarts := make([]int, 0, len(chunks))
for idx, chunk := range chunks {
if chunk.index == 0 {
groupStarts = append(groupStarts, idx)
}
}
if len(groupStarts) == 0 {
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
}
for i := len(groupStarts) - 1; i >= 0; i-- {
start := groupStarts[i]
end := len(chunks)
if i+1 < len(groupStarts) {
end = groupStarts[i+1]
}
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
return merged
}
}
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
}
func sanitizeSessionToken(raw string) string {
token := strings.TrimSpace(raw)
token = strings.Trim(token, "\"'`")
token = strings.TrimSuffix(token, ";")
return strings.TrimSpace(token)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
// RefreshAccountToken refreshes token for an OpenAI OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
if account.Platform != PlatformOpenAI {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
......@@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
if proxyID == nil {
return "", nil
}
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil {
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy == nil {
return "", nil
}
return proxy.URL(), nil
}
func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora:
return openai.OAuthPlatformSora
default:
return openai.OAuthPlatformOpenAI
}
return openai.OAuthPlatformOpenAI
}
......@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require.True(t, ok)
require.Equal(t, openai.ClientID, session.ClientID)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
defer svc.Stop()
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
require.NoError(t, err)
require.NotEmpty(t, result.AuthURL)
require.NotEmpty(t, result.SessionID)
parsed, err := url.Parse(result.AuthURL)
require.NoError(t, err)
q := parsed.Query()
require.Equal(t, openai.ClientID, q.Get("client_id"))
require.Empty(t, q.Get("codex_cli_simplified_flow"))
session, ok := svc.sessionStore.Get(result.SessionID)
require.True(t, ok)
require.Equal(t, openai.ClientID, session.ClientID)
}
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at-token", info.AccessToken)
require.Equal(t, "demo@example.com", info.Email)
require.Greater(t, info.ExpiresAt, int64(0))
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
raw := strings.Join([]string{
"set-cookie",
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
"set-cookie",
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
"set-cookie",
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
}, "\n")
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
require.NoError(t, err)
require.Equal(t, "at-token", info.AccessToken)
}
......@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
......@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai/sora oauth account")
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
......@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
} else if result.Refreshed {
p.metrics.refreshSuccess.Add(1)
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if result.Refreshed {
p.metrics.refreshSuccess.Add(1)
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment