Unverified Commit 342fd03e authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #986 from LvyuanW/openai-model-mapping-fix

fix: honor account model mapping before group fallback
parents e7086cb3 a377e990
...@@ -522,16 +522,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool { ...@@ -522,16 +522,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名 // 如果未配置 mapping,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
}
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel return requestedModel, false
} }
// 精确匹配优先 // 精确匹配优先
if mappedModel, exists := mapping[requestedModel]; exists { if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel return mappedModel, true
} }
// 通配符匹配(最长优先) // 通配符匹配(最长优先)
return matchWildcardMapping(mapping, requestedModel) return matchWildcardMappingResult(mapping, requestedModel)
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
...@@ -605,9 +612,7 @@ func matchWildcard(pattern, str string) bool { ...@@ -605,9 +612,7 @@ func matchWildcard(pattern, str string) bool {
return matchAntigravityWildcard(pattern, str) return matchAntigravityWildcard(pattern, str)
} }
// matchWildcardMapping 通配符映射匹配(最长优先) func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
// 如果没有匹配,返回原始字符串
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
// 收集所有匹配的 pattern,按长度降序排序(最长优先) // 收集所有匹配的 pattern,按长度降序排序(最长优先)
type patternMatch struct { type patternMatch struct {
pattern string pattern string
...@@ -622,7 +627,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri ...@@ -622,7 +627,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
} }
if len(matches) == 0 { if len(matches) == 0 {
return requestedModel // 无匹配,返回原始模型名 return requestedModel, false // 无匹配,返回原始模型名
} }
// 按 pattern 长度降序排序 // 按 pattern 长度降序排序
...@@ -633,7 +638,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri ...@@ -633,7 +638,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
return matches[i].pattern < matches[j].pattern return matches[i].pattern < matches[j].pattern
}) })
return matches[0].target return matches[0].target, true
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
......
...@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) { ...@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
} }
} }
func TestMatchWildcardMapping(t *testing.T) { func TestMatchWildcardMappingResult(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mapping map[string]string mapping map[string]string
requestedModel string requestedModel string
expected string expected string
matched bool
}{ }{
// 精确匹配优先于通配符 // 精确匹配优先于通配符
{ {
...@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-exact", expected: "claude-sonnet-4-5-exact",
matched: true,
}, },
// 最长通配符优先 // 最长通配符优先
...@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-series", expected: "claude-sonnet-4-series",
matched: true,
}, },
// 单个通配符 // 单个通配符
...@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-opus-4-5", requestedModel: "claude-opus-4-5",
expected: "claude-mapped", expected: "claude-mapped",
matched: true,
}, },
// 无匹配返回原始模型 // 无匹配返回原始模型
...@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "gemini-3-flash", requestedModel: "gemini-3-flash",
expected: "gemini-3-flash", expected: "gemini-3-flash",
matched: false,
}, },
// 空映射返回原始模型 // 空映射返回原始模型
...@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
mapping: map[string]string{}, mapping: map[string]string{},
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
matched: false,
}, },
// Gemini 模型映射 // Gemini 模型映射
...@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) { ...@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "gemini-3-flash-preview", requestedModel: "gemini-3-flash-preview",
expected: "gemini-3-pro-high", expected: "gemini-3-pro-high",
matched: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := matchWildcardMapping(tt.mapping, tt.requestedModel) result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
if result != tt.expected { if result != tt.expected || matched != tt.matched {
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
} }
}) })
} }
...@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
} }
} }
func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "wildcard passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "missing mapping reports unmatched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.2": "gpt-5.2",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{ account := &Account{
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
......
...@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
} }
// 3. Model mapping // 3. Model mapping
mappedModel := account.GetMappedModel(originalModel) mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied", logger.L().Debug("openai chat_completions: model mapping applied",
......
...@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 3. Model mapping // 3. Model mapping
mappedModel := account.GetMappedModel(originalModel) mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
// 分组级降级:账号未映射时使用分组默认映射模型
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied", logger.L().Debug("openai messages: model mapping applied",
......
package service
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
return defaultMappedModel
}
return requestedModel
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
return defaultMappedModel
}
return mappedModel
}
package service
import "testing"
func TestResolveOpenAIForwardModel(t *testing.T) {
tests := []struct {
name string
account *Account
requestedModel string
defaultMappedModel string
expectedModel string
}{
{
name: "falls back to group default when account has no mapping",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves wildcard passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "uses account remap when explicit target differs",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.4",
},
},
},
requestedModel: "gpt-5",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
}
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment