Unverified Commit 0537a490 authored by Oliver Li's avatar Oliver Li Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into vertex

parents 3f05ef2a c92b88e3
......@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
sessionHash := ""
if parsed.Multipart {
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
} else {
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
}
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
......
......@@ -258,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
}
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_read",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_read",
Name: "Read",
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.Equal(t, "tool_use", anth.Content[0].Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_other",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_other",
Name: "Search",
Arguments: `{"query":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
......@@ -472,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 0,
Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
}, state)
require.Len(t, events, 1)
assert.Equal(t, "content_block_start", events[0].Type)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 0,
Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
assert.Len(t, events, 0)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.done",
OutputIndex: 0,
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
require.Len(t, events, 2)
assert.Equal(t, "content_block_delta", events[0].Type)
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
assert.Equal(t, "content_block_stop", events[1].Type)
}
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
......
......@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
Input: json.RawMessage(item.Arguments),
Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
......@@ -129,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
}
}
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
if name != "Read" || raw == "" {
return json.RawMessage(raw)
}
var input map[string]json.RawMessage
if err := json.Unmarshal([]byte(raw), &input); err != nil {
return json.RawMessage(raw)
}
if pages, ok := input["pages"]; !ok || string(pages) != `""` {
return json.RawMessage(raw)
}
delete(input, "pages")
sanitized, err := json.Marshal(input)
if err != nil {
return json.RawMessage(raw)
}
return sanitized
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
......@@ -142,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
CurrentToolName string
CurrentToolArgs string
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
......@@ -181,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
return resToAnthHandleBlockDone(state)
return resToAnthHandleFuncArgsDone(evt, state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
......@@ -278,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
state.CurrentToolName = evt.Item.Name
state.CurrentToolArgs = ""
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
......@@ -358,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
return nil
}
if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
state.CurrentToolArgs += evt.Delta
return nil
}
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
......@@ -373,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
}}
}
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
return resToAnthHandleBlockDone(state)
}
raw := evt.Arguments
if raw == "" {
raw = state.CurrentToolArgs
}
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
if len(sanitized) == 0 {
return closeCurrentBlock(state)
}
idx := state.ContentBlockIndex
events := []AnthropicStreamEvent{{
Type: "content_block_delta",
Index: &idx,
Delta: &AnthropicDelta{
Type: "input_json_delta",
PartialJSON: string(sanitized),
},
}}
events = append(events, closeCurrentBlock(state)...)
return events
}
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
......@@ -524,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
state.CurrentToolName = ""
state.CurrentToolArgs = ""
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
......
......@@ -1125,6 +1125,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
return sessionID
}
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
return ""
}
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
attachOpenAILegacySessionHashToGin(c, legacyHash)
return currentHash
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
......@@ -1137,13 +1166,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
......
......@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
t.Run("stateless image body stays unstuck", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("header overrides body", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
c.Request.Header.Set("session_id", "header-session")
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
})
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment