Commit 4617ef2b authored by Jiahao Luo's avatar Jiahao Luo
Browse files

Fix OpenAI default model forwarding

parent 94bba415
...@@ -181,7 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -181,7 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
......
...@@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct { ...@@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct {
cfg *config.Config cfg *config.Config
} }
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
return fallbackModel
}
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
...@@ -657,9 +667,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -657,9 +667,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1
defaultMappedModel := c.GetString("openai_messages_fallback_model") defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
......
...@@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) { ...@@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
}) })
} }
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
})
t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
})
t.Run("returns_empty_without_group_default", func(t *testing.T) {
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
Group: &service.Group{},
}, ""))
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) { ...@@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
}) })
} }
} }
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
account := &Account{
Credentials: map[string]any{},
}
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1")
}
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
if got := normalizeCodexModel(withDefault); got != "gpt-5.4" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment