Unverified Commit 6663e1ed authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1420 from YanzheL/fix/1202-gemini-customtools-404

Fix Gemini CLI 404s for gemini-3.1-pro-preview-customtools
parents 83a16dec 649afef5
...@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { ...@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusBadGateway, err.Error()) googleError(c, http.StatusBadGateway, err.Error())
return return
} }
if shouldFallbackGeminiModels(res) { if shouldFallbackGeminiModel(modelName, res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return return
} }
...@@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { ...@@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
return false return false
} }
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
if shouldFallbackGeminiModels(res) {
return true
}
if res == nil || res.StatusCode != http.StatusNotFound {
return false
}
return gemini.HasFallbackModel(modelName)
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
// //
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package handler package handler
import ( import (
"net/http"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { ...@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
}) })
} }
} }
func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) {
t.Parallel()
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res))
}
func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) {
t.Parallel()
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
require.False(t, shouldFallbackGeminiModel("gemini-future-model", res))
}
func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) {
t.Parallel()
res := &service.UpstreamHTTPResult{
StatusCode: http.StatusForbidden,
Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}},
Body: []byte("insufficient authentication scopes"),
}
require.True(t, shouldFallbackGeminiModel("gemini-future-model", res))
}
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
package gemini package gemini
import "strings"
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"` DisplayName string `json:"displayName,omitempty"`
...@@ -23,10 +25,27 @@ func DefaultModels() []Model { ...@@ -23,10 +25,27 @@ func DefaultModels() []Model {
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
} }
} }
func HasFallbackModel(model string) bool {
trimmed := strings.TrimSpace(model)
if trimmed == "" {
return false
}
if !strings.HasPrefix(trimmed, "models/") {
trimmed = "models/" + trimmed
}
for _, model := range DefaultModels() {
if model.Name == trimmed {
return true
}
}
return false
}
func FallbackModelsList() ModelsListResponse { func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()} return ModelsListResponse{Models: DefaultModels()}
} }
......
...@@ -2,7 +2,7 @@ package gemini ...@@ -2,7 +2,7 @@ package gemini
import "testing" import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) { func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) {
t.Parallel() t.Parallel()
models := DefaultModels() models := DefaultModels()
...@@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { ...@@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
required := []string{ required := []string{
"models/gemini-2.5-flash-image", "models/gemini-2.5-flash-image",
"models/gemini-3.1-pro-preview-customtools",
"models/gemini-3.1-flash-image", "models/gemini-3.1-flash-image",
} }
...@@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { ...@@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
} }
} }
} }
func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) {
t.Parallel()
if !HasFallbackModel("gemini-3.1-pro-preview-customtools") {
t.Fatalf("expected customtools model to exist in fallback catalog")
}
if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") {
t.Fatalf("expected prefixed customtools model to exist in fallback catalog")
}
if HasFallbackModel("gemini-unknown") {
t.Fatalf("did not expect unknown model to exist in fallback catalog")
}
}
...@@ -515,18 +515,27 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st ...@@ -515,18 +515,27 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st
} }
} }
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) func normalizeRequestedModelForLookup(platform, requestedModel string) string {
// 如果未配置 mapping,返回 true(允许所有模型) trimmed := strings.TrimSpace(requestedModel)
func (a *Account) IsModelSupported(requestedModel string) bool { if trimmed == "" {
mapping := a.GetModelMapping() return ""
if len(mapping) == 0 { }
return true // 无映射 = 允许所有 if platform != PlatformGemini && platform != PlatformAntigravity {
return trimmed
}
if trimmed == "gemini-3.1-pro-preview-customtools" {
return "gemini-3.1-pro-preview"
}
return trimmed
}
func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool {
if requestedModel == "" {
return false
} }
// 精确匹配
if _, exists := mapping[requestedModel]; exists { if _, exists := mapping[requestedModel]; exists {
return true return true
} }
// 通配符匹配
for pattern := range mapping { for pattern := range mapping {
if matchWildcard(pattern, requestedModel) { if matchWildcard(pattern, requestedModel) {
return true return true
...@@ -535,6 +544,30 @@ func (a *Account) IsModelSupported(requestedModel string) bool { ...@@ -535,6 +544,30 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
return false return false
} }
func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) {
if requestedModel == "" {
return "", false
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel, true
}
return matchWildcardMappingResult(mapping, requestedModel)
}
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 如果未配置 mapping,返回 true(允许所有模型)
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true // 无映射 = 允许所有
}
if mappingSupportsRequestedModel(mapping, requestedModel) {
return true
}
normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
}
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名 // 如果未配置 mapping,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
...@@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, ...@@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
if len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel, false return requestedModel, false
} }
// 精确匹配优先 if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel, true return mappedModel, true
} }
// 通配符匹配(最长优先) normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
return matchWildcardMappingResult(mapping, requestedModel) if normalized != requestedModel {
if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched {
return mappedModel, true
}
}
return requestedModel, false
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
......
...@@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { ...@@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) {
func TestAccountIsModelSupported(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expected bool expected bool
...@@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) {
requestedModel: "claude-opus-4-5-thinking", requestedModel: "claude-opus-4-5-thinking",
expected: true, expected: true,
}, },
{
name: "gemini customtools alias matches normalized mapping",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: true,
},
{ {
name: "wildcard match not supported", name: "wildcard match not supported",
credentials: map[string]any{ credentials: map[string]any{
...@@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
result := account.IsModelSupported(tt.requestedModel) result := account.IsModelSupported(tt.requestedModel)
...@@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { ...@@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) {
func TestAccountGetMappedModel(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expected string expected string
...@@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) {
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
{
name: "no mapping preserves gemini customtools model",
platform: PlatformGemini,
credentials: nil,
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
// 精确匹配 // 精确匹配
{ {
...@@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) {
}, },
// 无匹配返回原始模型 // 无匹配返回原始模型
{
name: "gemini customtools alias resolves through normalized mapping",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview",
},
{
name: "gemini customtools exact mapping wins over normalized fallback",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-preview-customtools",
},
{ {
name: "no match returns original", name: "no match returns original",
credentials: map[string]any{ credentials: map[string]any{
...@@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
result := account.GetMappedModel(tt.requestedModel) result := account.GetMappedModel(tt.requestedModel)
...@@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { ...@@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) {
func TestAccountResolveMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
platform string
credentials map[string]any credentials map[string]any
requestedModel string requestedModel string
expectedModel string expectedModel string
...@@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { ...@@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) {
expectedModel: "gpt-5.4", expectedModel: "gpt-5.4",
expectedMatch: true, expectedMatch: true,
}, },
{
name: "gemini customtools alias reports normalized match",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expectedModel: "gemini-3.1-pro-preview",
expectedMatch: true,
},
{
name: "gemini customtools exact mapping reports exact match",
platform: PlatformGemini,
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
},
},
requestedModel: "gemini-3.1-pro-preview-customtools",
expectedModel: "gemini-3.1-pro-preview-customtools",
expectedMatch: true,
},
{ {
name: "missing mapping reports unmatched", name: "missing mapping reports unmatched",
credentials: map[string]any{ credentials: map[string]any{
...@@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { ...@@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
account := &Account{ account := &Account{
Platform: tt.platform,
Credentials: tt.credentials, Credentials: tt.credentials,
} }
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
......
...@@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { ...@@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
requestedModel: "gemini-2.5-flash", requestedModel: "gemini-2.5-flash",
expected: "gemini-2.5-flash", expected: "gemini-2.5-flash",
}, },
{
name: "customtools alias falls back to normalized preview mapping",
modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"},
requestedModel: "gemini-3.1-pro-preview-customtools",
expected: "gemini-3.1-pro-high",
},
} }
for _, tt := range tests { for _, tt := range tests {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment