Commit 04b2866f authored by ivanvolt's avatar ivanvolt
Browse files

fix: use Responses-compatible function tool_choice format

parent b0a2252e
...@@ -991,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) { ...@@ -991,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
var tc map[string]any var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"]) assert.Equal(t, "function", tc["type"])
fn, ok := tc["function"].(map[string]any) assert.Equal(t, "get_weather", tc["name"])
require.True(t, ok) assert.NotContains(t, tc, "function")
assert.Equal(t, "get_weather", fn["name"]) }
func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
}
func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
......
...@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { ...@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto" // {"type":"auto"} → "auto"
// {"type":"any"} → "required" // {"type":"any"} → "required"
// {"type":"none"} → "none" // {"type":"none"} → "none"
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}} // {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) { func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct { var tc struct {
Type string `json:"type"` Type string `json:"type"`
...@@ -95,7 +95,7 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage ...@@ -95,7 +95,7 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
case "tool": case "tool":
return json.Marshal(map[string]any{ return json.Marshal(map[string]any{
"type": "function", "type": "function",
"function": map[string]string{"name": tc.Name}, "name": tc.Name,
}) })
default: default:
// Pass through unknown types as-is // Pass through unknown types as-is
......
...@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { ...@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var tc map[string]any var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"]) assert.Equal(t, "function", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
assert.NotContains(t, tc, "function")
} }
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
......
...@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R ...@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
// //
// "auto" → "auto" // "auto" → "auto"
// "none" → "none" // "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}} // {"name":"X"} → {"type":"function","name":"X"}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is. // Try string first ("auto", "none", etc.) — pass through as-is.
var s string var s string
...@@ -437,6 +437,6 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, ...@@ -437,6 +437,6 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
} }
return json.Marshal(map[string]any{ return json.Marshal(map[string]any{
"type": "function", "type": "function",
"function": map[string]string{"name": obj.Name}, "name": obj.Name,
}) })
} }
...@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { ...@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
// "auto" → {"type":"auto"} // "auto" → {"type":"auto"}
// "required" → {"type":"any"} // "required" → {"type":"any"}
// "none" → {"type":"none"} // "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // {"type":"function","name":"X"} → {"type":"tool","name":"X"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) { func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first // Try as string first
var s string var s string
...@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage ...@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
// Try as object with type=function // Try as object with type=function
var tc struct { var tc struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"`
Function struct { Function struct {
Name string `json:"name"` Name string `json:"name"`
} `json:"function"` } `json:"function"`
} }
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
name := strings.TrimSpace(tc.Name)
if name == "" {
name = strings.TrimSpace(tc.Function.Name)
}
if name == "" {
return raw, nil
}
return json.Marshal(map[string]string{ return json.Marshal(map[string]string{
"type": "tool", "type": "tool",
"name": tc.Function.Name, "name": name,
}) })
} }
......
...@@ -141,9 +141,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact ...@@ -141,9 +141,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{ reqBody["tool_choice"] = map[string]any{
"type": "function", "type": "function",
"function": map[string]any{
"name": name, "name": name,
},
} }
} }
} }
...@@ -219,9 +217,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool { ...@@ -219,9 +217,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
return false return false
} }
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"])) choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) { if choiceType == "" {
return false return false
} }
modified := false
if choiceType == "function" {
name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
if name == "" {
if function, ok := choiceMap["function"].(map[string]any); ok {
name = strings.TrimSpace(firstNonEmptyString(function["name"]))
}
}
if name == "" {
reqBody["tool_choice"] = "auto"
return true
}
if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
choiceMap["name"] = name
modified = true
}
if _, ok := choiceMap["function"]; ok {
delete(choiceMap, "function")
modified = true
}
if !codexToolsContainFunctionName(reqBody["tools"], name) {
reqBody["tool_choice"] = "auto"
return true
}
return modified
}
if codexToolsContainType(reqBody["tools"], choiceType) {
return modified
}
reqBody["tool_choice"] = "auto" reqBody["tool_choice"] = "auto"
return true return true
} }
...@@ -243,6 +270,33 @@ func codexToolsContainType(rawTools any, toolType string) bool { ...@@ -243,6 +270,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
return false return false
} }
func codexToolsContainFunctionName(rawTools any, name string) bool {
tools, ok := rawTools.([]any)
if !ok || strings.TrimSpace(name) == "" {
return false
}
normalizedName := strings.TrimSpace(name)
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
continue
}
toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
if toolName == "" {
if function, ok := tool["function"].(map[string]any); ok {
toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
}
}
if toolName == normalizedName {
return true
}
}
return false
}
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) { func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 { if len(input) == 0 {
return input, false return input, false
......
...@@ -249,6 +249,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) { ...@@ -249,6 +249,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
require.Equal(t, "custom", choice["type"]) require.Equal(t, "custom", choice["type"])
} }
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{
"type": "function",
"function": map[string]any{"name": "shell"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
choice, ok := reqBody["tool_choice"].(map[string]any)
require.True(t, ok)
require.Equal(t, "function", choice["type"])
require.Equal(t, "shell", choice["name"])
require.NotContains(t, choice, "function")
}
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"tools": []any{
map[string]any{"type": "function", "name": "shell"},
},
"tool_choice": map[string]any{
"type": "function",
"function": map[string]any{"name": "missing"},
},
}
applyCodexOAuthTransform(reqBody, true, false)
require.Equal(t, "auto", reqBody["tool_choice"])
}
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) { func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.4", "model": "gpt-5.4",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment