Unverified Commit 7f8f3fe0 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #2100 from KnowSky404/fix/codex-cli-edit-resend-tool-continuation

[codex] fix WS continuation inference for explicit tool replay
parents 46f06b24 f7c13af1
......@@ -1366,16 +1366,25 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
func shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled bool,
turn int,
hasFunctionCallOutput bool,
signals ToolContinuationSignals,
currentPreviousResponseID string,
expectedPreviousResponseID string,
) bool {
if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
return false
}
if strings.TrimSpace(currentPreviousResponseID) != "" {
return false
}
if signals.HasFunctionCallOutputMissingCallID {
return false
}
// If the client already sent tool-call context or item_reference anchors,
// treat this as a full replay / self-contained continuation payload rather
// than downgrading it into an inferred delta continuation.
if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs {
return false
}
return strings.TrimSpace(expectedPreviousResponseID) != ""
}
......@@ -3179,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
toolSignals := ToolContinuationSignals{
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
}
if toolSignals.HasFunctionCallOutput {
var currentReqBody map[string]any
if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil {
toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
}
}
hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
// store=false + function_call_output 场景必须有续链锚点。
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
if shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled,
turn,
hasFunctionCallOutput,
toolSignals,
currentPreviousResponseID,
expectedPrev,
) {
......
......@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{captureConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 114,
Name: "openai-ingress-tool-context",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
secondTurn := readMessage()
require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
require.Equal(t, 1, captureDialer.DialCount())
require.Len(t, captureConn.writes, 2)
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{captureConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 115,
Name: "openai-ingress-item-reference",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
secondTurn := readMessage()
require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
require.Equal(t, 1, captureDialer.DialCount())
require.Len(t, captureConn.writes, 2)
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
......
......@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
name string
storeDisabled bool
turn int
hasFunctionCallOutput bool
signals ToolContinuationSignals
currentPreviousResponse string
expectedPrevious string
want bool
}{
{
name: "infer_when_all_conditions_match",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: true,
name: "infer_when_all_conditions_match",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: true,
},
{
name: "skip_when_store_enabled",
storeDisabled: false,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: false,
name: "skip_when_store_enabled",
storeDisabled: false,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_on_first_turn",
storeDisabled: true,
turn: 1,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
want: false,
name: "skip_on_first_turn",
storeDisabled: true,
turn: 1,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_without_function_call_output",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: false,
expectedPrevious: "resp_1",
want: false,
name: "skip_without_function_call_output",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{},
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_request_already_has_previous_response_id",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
currentPreviousResponse: "resp_client",
expectedPrevious: "resp_1",
want: false,
},
{
name: "skip_when_last_turn_response_id_missing",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: "",
want: false,
name: "skip_when_last_turn_response_id_missing",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: "",
want: false,
},
{
name: "trim_whitespace_before_judgement",
storeDisabled: true,
turn: 2,
hasFunctionCallOutput: true,
expectedPrevious: " resp_2 ",
want: true,
name: "trim_whitespace_before_judgement",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
expectedPrevious: " resp_2 ",
want: true,
},
{
name: "skip_when_tool_call_context_already_present",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
expectedPrevious: "resp_2",
want: false,
},
{
name: "skip_when_item_reference_already_covers_all_call_ids",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
expectedPrevious: "resp_2",
want: false,
},
{
name: "skip_when_function_call_output_missing_call_id",
storeDisabled: true,
turn: 2,
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
expectedPrevious: "resp_2",
want: false,
},
}
......@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
tt.storeDisabled,
tt.turn,
tt.hasFunctionCallOutput,
tt.signals,
tt.currentPreviousResponse,
tt.expectedPrevious,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment