"...internal/handler/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "8e69957bb19ccb10a3e485259c5ec0c6319321be"
Unverified Commit 77ba9e72 authored by Yanzhe Lee's avatar Yanzhe Lee Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into fix/openai-gateway-content-session-hash-fallback

parents cf9efefd 055c48ab
...@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { ...@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusBadGateway, err.Error()) googleError(c, http.StatusBadGateway, err.Error())
return return
} }
if shouldFallbackGeminiModels(res) { if shouldFallbackGeminiModel(modelName, res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return return
} }
...@@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { ...@@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
return false return false
} }
// shouldFallbackGeminiModel reports whether the gateway should serve locally
// generated fallback metadata for modelName instead of the upstream response.
// It holds when the generic scope-based fallback applies, or when upstream
// answered 404 for a model that exists in the local fallback catalog.
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
	if shouldFallbackGeminiModels(res) {
		return true
	}
	if res != nil && res.StatusCode == http.StatusNotFound {
		return gemini.HasFallbackModel(modelName)
	}
	return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
// //
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package handler package handler
import ( import (
"net/http"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { ...@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
}) })
} }
} }
func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) {
t.Parallel()
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res))
}
func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) {
t.Parallel()
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
require.False(t, shouldFallbackGeminiModel("gemini-future-model", res))
}
// TestShouldFallbackGeminiModel_DelegatesScopeFallback verifies that the
// scope-based fallback (insufficient OAuth scopes on a 403) still applies
// even when the model is not in the local fallback catalog.
func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) {
	t.Parallel()
	headers := http.Header{}
	headers.Set("Www-Authenticate", `Bearer error="insufficient_scope"`)
	res := &service.UpstreamHTTPResult{
		StatusCode: http.StatusForbidden,
		Headers:    headers,
		Body:       []byte("insufficient authentication scopes"),
	}
	require.True(t, shouldFallbackGeminiModel("gemini-future-model", res))
}
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
package gemini package gemini
import "strings"
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"` DisplayName string `json:"displayName,omitempty"`
...@@ -23,10 +25,27 @@ func DefaultModels() []Model { ...@@ -23,10 +25,27 @@ func DefaultModels() []Model {
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
} }
} }
// HasFallbackModel reports whether model exists in the local fallback catalog
// returned by DefaultModels. The name may be given with or without the
// "models/" prefix; surrounding whitespace is ignored.
func HasFallbackModel(model string) bool {
	name := strings.TrimSpace(model)
	if name == "" {
		return false
	}
	// Catalog entries are stored fully qualified, so normalize bare names.
	if !strings.HasPrefix(name, "models/") {
		name = "models/" + name
	}
	// NOTE: the original loop variable shadowed the `model` parameter;
	// renamed for clarity, behavior unchanged.
	for _, m := range DefaultModels() {
		if m.Name == name {
			return true
		}
	}
	return false
}
func FallbackModelsList() ModelsListResponse { func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()} return ModelsListResponse{Models: DefaultModels()}
} }
......
...@@ -2,7 +2,7 @@ package gemini ...@@ -2,7 +2,7 @@ package gemini
import "testing" import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) { func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) {
t.Parallel() t.Parallel()
models := DefaultModels() models := DefaultModels()
...@@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { ...@@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
required := []string{ required := []string{
"models/gemini-2.5-flash-image", "models/gemini-2.5-flash-image",
"models/gemini-3.1-pro-preview-customtools",
"models/gemini-3.1-flash-image", "models/gemini-3.1-flash-image",
} }
...@@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { ...@@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
} }
} }
} }
// TestHasFallbackModel_RecognizesCustomtoolsModel checks catalog membership
// with and without the "models/" prefix, plus an unknown model.
func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) {
	t.Parallel()
	cases := []struct {
		model string
		want  bool
		msg   string
	}{
		{"gemini-3.1-pro-preview-customtools", true, "expected customtools model to exist in fallback catalog"},
		{"models/gemini-3.1-pro-preview-customtools", true, "expected prefixed customtools model to exist in fallback catalog"},
		{"gemini-unknown", false, "did not expect unknown model to exist in fallback catalog"},
	}
	for _, tc := range cases {
		if HasFallbackModel(tc.model) != tc.want {
			t.Fatal(tc.msg)
		}
	}
}
...@@ -515,18 +515,27 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st ...@@ -515,18 +515,27 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st
} }
} }
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) func normalizeRequestedModelForLookup(platform, requestedModel string) string {
// 如果未配置 mapping,返回 true(允许所有模型) trimmed := strings.TrimSpace(requestedModel)
func (a *Account) IsModelSupported(requestedModel string) bool { if trimmed == "" {
mapping := a.GetModelMapping() return ""
if len(mapping) == 0 { }
return true // 无映射 = 允许所有 if platform != PlatformGemini && platform != PlatformAntigravity {
return trimmed
}
if trimmed == "gemini-3.1-pro-preview-customtools" {
return "gemini-3.1-pro-preview"
}
return trimmed
}
func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool {
if requestedModel == "" {
return false
} }
// 精确匹配
if _, exists := mapping[requestedModel]; exists { if _, exists := mapping[requestedModel]; exists {
return true return true
} }
// 通配符匹配
for pattern := range mapping { for pattern := range mapping {
if matchWildcard(pattern, requestedModel) { if matchWildcard(pattern, requestedModel) {
return true return true
...@@ -535,6 +544,30 @@ func (a *Account) IsModelSupported(requestedModel string) bool { ...@@ -535,6 +544,30 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
return false return false
} }
// resolveRequestedModelInMapping resolves requestedModel against mapping,
// preferring an exact key match and falling back to wildcard matching.
// matched reports whether any mapping entry applied.
func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) {
	if requestedModel == "" {
		return "", false
	}
	if target, ok := mapping[requestedModel]; ok {
		return target, true
	}
	return matchWildcardMappingResult(mapping, requestedModel)
}
// IsModelSupported reports whether requestedModel is allowed by the account's
// model_mapping (wildcards supported). An empty mapping allows every model.
// When the raw name does not match, a platform-specific normalized alias is
// also tried (see normalizeRequestedModelForLookup).
func (a *Account) IsModelSupported(requestedModel string) bool {
	mapping := a.GetModelMapping()
	if len(mapping) == 0 {
		// No mapping configured: allow all models.
		return true
	}
	if mappingSupportsRequestedModel(mapping, requestedModel) {
		return true
	}
	normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
	if normalized == requestedModel {
		return false
	}
	return mappingSupportsRequestedModel(mapping, normalized)
}
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名 // 如果未配置 mapping,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
...@@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, ...@@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
if len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel, false return requestedModel, false
} }
// 精确匹配优先 if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel, true return mappedModel, true
} }
// 通配符匹配(最长优先) normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
return matchWildcardMappingResult(mapping, requestedModel) if normalized != requestedModel {
if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched {
return mappedModel, true
}
}
return requestedModel, false
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
......
...@@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { ...@@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) {
func TestAccountIsModelSupported(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expected bool expected bool
...@@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) {
requestedModel: "claude-opus-4-5-thinking", requestedModel: "claude-opus-4-5-thinking",
expected: true, expected: true,
}, },
{
name: "gemini customtools alias matches normalized mapping",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: true,
},
{ {
name: "wildcard match not supported", name: "wildcard match not supported",
credentials: map[string]any{ credentials: map[string]any{
...@@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
result := account.IsModelSupported(tt.requestedModel) result := account.IsModelSupported(tt.requestedModel)
...@@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) {
func TestAccountGetMappedModel(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expected string expected string
...@@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) {
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
{
name: "no mapping preserves gemini customtools model",
platform: PlatformGemini,
credentials: nil,
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
// 精确匹配 // 精确匹配
{ {
...@@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) {
}, },
// 无匹配返回原始模型 // 无匹配返回原始模型
{
name: "gemini customtools alias resolves through normalized mapping",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview",
},
{
name: "gemini customtools exact mapping wins over normalized fallback",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
{ {
name: "no match returns original", name: "no match returns original",
credentials: map[string]any{ credentials: map[string]any{
...@@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
result := account.GetMappedModel(tt.requestedModel) result := account.GetMappedModel(tt.requestedModel)
...@@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) {
func TestAccountResolveMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expectedModel string expectedModel string
...@@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { ...@@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) {
expectedModel: "gpt-5.4", expectedModel: "gpt-5.4",
expectedMatch: true, expectedMatch: true,
}, },
{
name: "gemini customtools alias reports normalized match",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expectedModel: "gemini-3.1-pro-preview",
expectedMatch: true,
},
{
name: "gemini customtools exact mapping reports exact match",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expectedModel: "gemini-3.1-pro-preview-customtools",
expectedMatch: true,
},
{ {
name: "missing mapping reports unmatched", name: "missing mapping reports unmatched",
credentials: map[string]any{ credentials: map[string]any{
...@@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { ...@@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
......
...@@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { ...@@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
requestedModel: "gemini-2.5-flash", requestedModel: "gemini-2.5-flash",
expected: "gemini-2.5-flash", expected: "gemini-2.5-flash",
}, },
{
name: "customtools alias falls back to normalized preview mapping",
modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-high",
},
} }
for _, tt := range tests { for _, tt := range tests {
......
...@@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact ...@@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if v, ok := reqBody["model"].(string); ok { if v, ok := reqBody["model"].(string); ok {
model = v model = v
} }
normalizedModel := normalizeCodexModel(model) normalizedModel := strings.TrimSpace(model)
if normalizedModel != "" { if normalizedModel != "" {
if model != normalizedModel { if model != normalizedModel {
reqBody["model"] = normalizedModel reqBody["model"] = normalizedModel
......
...@@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { ...@@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt 5.3 codex spark": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt 5.3 codex": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex",
...@@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { ...@@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
} }
} }
// TestApplyCodexOAuthTransform_PreservesBareSparkModel verifies the bare spark
// model name survives the transform untouched and that store is forced false.
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
	reqBody := map[string]any{
		"input": []any{},
		"model": "gpt-5.3-codex-spark",
	}
	result := applyCodexOAuthTransform(reqBody, false, false)
	require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel)
	require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"])
	storeVal, isBool := reqBody["store"].(bool)
	require.True(t, isBool)
	require.False(t, storeVal)
}
// TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite verifies that
// surrounding whitespace is trimmed from the model (marking the request
// modified) without rewriting the model name itself.
func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) {
	reqBody := map[string]any{
		"input": []any{},
		"model": " gpt-5.3-codex-spark ",
	}
	result := applyCodexOAuthTransform(reqBody, false, false)
	require.True(t, result.Modified)
	require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel)
	require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"])
}
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时不修改 // Codex CLI 场景:已有 instructions 时不修改
......
...@@ -10,8 +10,8 @@ import ( ...@@ -10,8 +10,8 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_" const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch normalizeCodexModel(strings.TrimSpace(model)) { switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex": case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true return true
default: default:
return false return false
...@@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod ...@@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return "" return ""
} }
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
if normalizedModel == "" { if normalizedModel == "" {
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
} }
if normalizedModel == "" { if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model) normalizedModel = strings.TrimSpace(req.Model)
......
...@@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { ...@@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark"))
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
} }
...@@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { ...@@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
require.NotEqual(t, k1, k2, "different first user messages should yield different keys") require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
} }
// TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily asserts the derived
// compat cache key is non-empty and stable across mapped-model spellings that
// resolve to the same spark family.
func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
	req := &apicompat.ChatCompletionsRequest{
		Model: "gpt-5.3-codex-spark",
		Messages: []apicompat.ChatMessage{
			{Role: "user", Content: mustRawJSON(t, `"Question A"`)},
		},
	}
	keyExact := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark")
	keyAliased := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ")
	require.NotEmpty(t, keyExact)
	require.Equal(t, keyExact, keyAliased, "resolved spark family should derive a stable compat cache key")
}
...@@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can // 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family. // derive a stable seed from the final upstream model family.
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey) promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false compatPromptCacheInjected := false
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) { if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel) promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
compatPromptCacheInjected = promptCacheKey != "" compatPromptCacheInjected = promptCacheKey != ""
} }
...@@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
if err != nil { if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err) return nil, fmt.Errorf("convert chat completions to responses: %w", err)
} }
responsesReq.Model = mappedModel responsesReq.Model = upstreamModel
logFields := []zap.Field{ logFields := []zap.Field{
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel), zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel), zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream), zap.Bool("stream", clientStream),
} }
if compatPromptCacheInjected { if compatPromptCacheInjected {
...@@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err) return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
} }
codexResult := applyCodexOAuthTransform(reqBody, false, false) codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" { if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" { } else if promptCacheKey != "" {
...@@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
var result *OpenAIForwardResult var result *OpenAIForwardResult
var handleErr error var handleErr error
if clientStream { if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
} else { } else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
} }
// Propagate ServiceTier and ReasoningEffort to result for billing // Propagate ServiceTier and ReasoningEffort to result for billing
...@@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ...@@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
mappedModel string, billingModel string,
upstreamModel string,
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
...@@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ...@@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: billingModel,
UpstreamModel: mappedModel, UpstreamModel: upstreamModel,
Stream: false, Stream: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
}, nil }, nil
...@@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ...@@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
mappedModel string, billingModel string,
upstreamModel string,
includeUsage bool, includeUsage bool,
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
...@@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ...@@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: billingModel,
UpstreamModel: mappedModel, UpstreamModel: upstreamModel,
Stream: true, Stream: true,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
originalModel := anthropicReq.Model originalModel := anthropicReq.Model
applyOpenAICompatModelNormalization(&anthropicReq) applyOpenAICompatModelNormalization(&anthropicReq)
normalizedModel := anthropicReq.Model
clientStream := anthropicReq.Stream // client's original stream preference clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses // 2. Convert Anthropic → Responses
...@@ -60,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -60,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 3. Model mapping // 3. Model mapping
mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
responsesReq.Model = mappedModel upstreamModel := resolveOpenAIUpstreamModel(billingModel)
responsesReq.Model = upstreamModel
logger.L().Debug("openai messages: model mapping applied", logger.L().Debug("openai messages: model mapping applied",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel), zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel), zap.String("normalized_model", normalizedModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", isStream), zap.Bool("stream", isStream),
) )
...@@ -82,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -82,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err) return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
} }
codexResult := applyCodexOAuthTransform(reqBody, false, false) codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" { if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" { } else if promptCacheKey != "" {
...@@ -182,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -182,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
var result *OpenAIForwardResult var result *OpenAIForwardResult
var handleErr error var handleErr error
if clientStream { if clientStream {
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
} else { } else {
// Client wants JSON: buffer the streaming response and assemble a JSON reply. // Client wants JSON: buffer the streaming response and assemble a JSON reply.
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
} }
// Propagate ServiceTier and ReasoningEffort to result for billing // Propagate ServiceTier and ReasoningEffort to result for billing
...@@ -230,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ...@@ -230,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
mappedModel string, billingModel string,
upstreamModel string,
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
...@@ -303,8 +311,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ...@@ -303,8 +311,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: billingModel,
UpstreamModel: mappedModel, UpstreamModel: upstreamModel,
Stream: false, Stream: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
}, nil }, nil
...@@ -319,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ...@@ -319,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
mappedModel string, billingModel string,
upstreamModel string,
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
...@@ -352,8 +361,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ...@@ -352,8 +361,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: billingModel,
UpstreamModel: mappedModel, UpstreamModel: upstreamModel,
Stream: true, Stream: true,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -1818,29 +1818,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -1818,29 +1818,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// 对所有请求执行模型映射(包含 Codex CLI)。 // 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel) billingModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if billingModel != reqModel {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel reqBody["model"] = billingModel
bodyModified = true bodyModified = true
markPatchSet("model", mappedModel) markPatchSet("model", billingModel)
} }
upstreamModel := billingModel
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok { if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model) upstreamModel = resolveOpenAIUpstreamModel(model)
if normalizedModel != "" && normalizedModel != model { if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI) model, upstreamModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel reqBody["model"] = upstreamModel
mappedModel = normalizedModel
bodyModified = true bodyModified = true
markPatchSet("model", normalizedModel) markPatchSet("model", upstreamModel)
} }
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数 // 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错 // 确保高版本模型向低版本模型映射不报错
if !SupportsVerbosity(normalizedModel) { if !SupportsVerbosity(upstreamModel) {
if text, ok := reqBody["text"].(map[string]any); ok { if text, ok := reqBody["text"].(map[string]any); ok {
delete(text, "verbosity") delete(text, "verbosity")
} }
...@@ -1864,7 +1864,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -1864,7 +1864,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch() disablePatch()
} }
if codexResult.NormalizedModel != "" { if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel upstreamModel = codexResult.NormalizedModel
} }
if codexResult.PromptCacheKey != "" { if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey promptCacheKey = codexResult.PromptCacheKey
...@@ -1981,7 +1981,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -1981,7 +1981,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
"forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
account.ID, account.ID,
account.Type, account.Type,
mappedModel, upstreamModel,
reqStream, reqStream,
hasPreviousResponseID, hasPreviousResponseID,
) )
...@@ -2070,7 +2070,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2070,7 +2070,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
isCodexCLI, isCodexCLI,
reqStream, reqStream,
originalModel, originalModel,
mappedModel, upstreamModel,
startTime, startTime,
attempt, attempt,
wsLastFailureReason, wsLastFailureReason,
...@@ -2171,7 +2171,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2171,7 +2171,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs, firstTokenMs,
wsAttempts, wsAttempts,
) )
wsResult.UpstreamModel = mappedModel wsResult.UpstreamModel = upstreamModel
return wsResult, nil return wsResult, nil
} }
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
...@@ -2276,14 +2276,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2276,14 +2276,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
var usage *OpenAIUsage var usage *OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
if reqStream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
usage = streamResult.usage usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2307,7 +2307,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2307,7 +2307,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel, UpstreamModel: upstreamModel,
ServiceTier: serviceTier, ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort, ReasoningEffort: reasoningEffort,
Stream: reqStream, Stream: reqStream,
......
package service package service
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible import "strings"
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule. // resolveOpenAIForwardModel resolves the account/group mapping result for
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil { if account == nil {
if defaultMappedModel != "" { if defaultMappedModel != "" {
...@@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo ...@@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
} }
return mappedModel return mappedModel
} }
// resolveOpenAIUpstreamModel resolves the model identifier that is actually
// sent upstream. The bare "gpt-5.3-codex-spark" model (in any of its accepted
// spellings) is pinned to its canonical form; every other name is passed
// through the regular Codex model normalization.
func resolveOpenAIUpstreamModel(model string) string {
	trimmed := strings.TrimSpace(model)
	if isBareGPT53CodexSparkModel(trimmed) {
		return "gpt-5.3-codex-spark"
	}
	return normalizeCodexModel(trimmed)
}
// isBareGPT53CodexSparkModel reports whether model names exactly the bare
// "gpt-5.3-codex-spark" model, ignoring surrounding whitespace, letter case,
// and any provider prefix (e.g. "openai/gpt-5.3-codex-spark"). Suffixed
// variants such as "gpt-5.3-codex-spark-high" do NOT match.
func isBareGPT53CodexSparkModel(model string) bool {
	id := strings.TrimSpace(model)
	if id == "" {
		return false
	}
	// Keep only the segment after the last "/" so provider-prefixed names match.
	if slash := strings.LastIndex(id, "/"); slash >= 0 {
		id = id[slash+1:]
	}
	switch strings.ToLower(strings.TrimSpace(id)) {
	case "gpt-5.3-codex-spark", "gpt 5.3 codex spark":
		return true
	default:
		return false
	}
}
...@@ -74,13 +74,30 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * ...@@ -74,13 +74,30 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{}, Credentials: map[string]any{},
} }
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" { if withoutDefault != "gpt-5.1" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1") t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
} }
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if got := normalizeCodexModel(withDefault); got != "gpt-5.4" { if withDefault != "gpt-5.4" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4") t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
}
}
func TestResolveOpenAIUpstreamModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
}
for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
}
} }
} }
...@@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ...@@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
normalized = next normalized = next
} }
mappedModel := account.GetMappedModel(originalModel) upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { if upstreamModel != originalModel {
mappedModel = normalizedModel next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
}
if mappedModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", mappedModel)
if setErr != nil { if setErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
} }
...@@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ...@@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := "" mappedModel := ""
var mappedModelBytes []byte var mappedModelBytes []byte
if originalModel != "" { if originalModel != "" {
mappedModel = account.GetMappedModel(originalModel) mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
mappedModel = normalizedModel
}
needModelReplace = mappedModel != "" && mappedModel != originalModel needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace { if needModelReplace {
mappedModelBytes = []byte(mappedModel) mappedModelBytes = []byte(mappedModel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment