Commit 9d814679 authored by shaw's avatar shaw
Browse files

refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换

将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的
类型安全转换路径,与 Anthropic Messages 端点架构对齐:

- 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器
- 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游
- Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler
- 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用
- 删除旧的 forwardChatCompletions 直传路径及相关死代码
parent 8dd38f47
package apicompat
import (
"encoding/json"
"fmt"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
// included so that the response translator has full context.
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
input, err := convertChatMessagesToResponsesInput(req.Messages)
if err != nil {
return nil, err
}
inputJSON, err := json.Marshal(input)
if err != nil {
return nil, err
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
storeFalse := false
out.Store = &storeFalse
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
maxTokens := 0
if req.MaxTokens != nil {
maxTokens = *req.MaxTokens
}
if req.MaxCompletionTokens != nil {
maxTokens = *req.MaxCompletionTokens
}
if maxTokens > 0 {
v := maxTokens
if v < minMaxOutputTokens {
v = minMaxOutputTokens
}
out.MaxOutputTokens = &v
}
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
if req.ReasoningEffort != "" {
out.Reasoning = &ResponsesReasoning{
Effort: req.ReasoningEffort,
Summary: "auto",
}
}
// tools[] and legacy functions[] → ResponsesTool[]
if len(req.Tools) > 0 || len(req.Functions) > 0 {
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
}
// tool_choice: already compatible format — pass through directly.
// Legacy function_call needs mapping.
if len(req.ToolChoice) > 0 {
out.ToolChoice = req.ToolChoice
} else if len(req.FunctionCall) > 0 {
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
if err != nil {
return nil, fmt.Errorf("convert function_call: %w", err)
}
out.ToolChoice = tc
}
return out, nil
}
// convertChatMessagesToResponsesInput converts the Chat Completions messages
// array into a Responses API input items array.
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem
for _, m := range msgs {
items, err := chatMessageToResponsesItems(m)
if err != nil {
return nil, err
}
out = append(out, items...)
}
return out, nil
}
// chatMessageToResponsesItems converts a single ChatMessage into one or more
// ResponsesInputItem values.
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
switch m.Role {
case "system":
return chatSystemToResponses(m)
case "user":
return chatUserToResponses(m)
case "assistant":
return chatAssistantToResponses(m)
case "tool":
return chatToolToResponses(m)
case "function":
return chatFunctionToResponses(m)
default:
return chatUserToResponses(m)
}
}
// chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
content, err := json.Marshal(text)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
}
// chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first.
var s string
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
return nil, fmt.Errorf("parse user content: %w", err)
}
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
// chatAssistantToResponses converts an assistant message. If there is both
// text content and tool_calls, the text is emitted as an assistant message
// first, then each tool_call becomes a function_call item. If the content is
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
}
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
if args == "" {
args = "{}"
}
items = append(items, ResponsesInputItem{
Type: "function_call",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.ToolCallID,
Output: output,
}}, nil
}
// chatFunctionToResponses converts a legacy function result message
// (role=function) into a function_call_output item. The Name field is used as
// call_id since legacy function calls do not carry a separate call_id.
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.Name,
Output: output,
}}, nil
}
// parseChatContent returns the string value of a ChatMessage Content field.
// Content must be a JSON string. Returns "" if content is null or empty.
func parseChatContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return "", fmt.Errorf("parse content as string: %w", err)
}
return s, nil
}
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
// function definitions to Responses API tool definitions.
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
var out []ResponsesTool
for _, t := range tools {
if t.Type != "function" || t.Function == nil {
continue
}
rt := ResponsesTool{
Type: "function",
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
Strict: t.Function.Strict,
}
out = append(out, rt)
}
// Legacy functions[] are treated as function-type tools.
for _, f := range functions {
rt := ResponsesTool{
Type: "function",
Name: f.Name,
Description: f.Description,
Parameters: f.Parameters,
Strict: f.Strict,
}
out = append(out, rt)
}
return out
}
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
// Responses API tool_choice value.
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Object form: {"name":"X"}
var obj struct {
Name string `json:"name"`
}
if err := json.Unmarshal(raw, &obj); err != nil {
return nil, err
}
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": obj.Name},
})
}
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions converts a Responses API response into a Chat
// Completions response. Text output items are concatenated into
// choices[0].message.content; function_call items become tool_calls.
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
id := resp.ID
if id == "" {
id = generateChatCmplID()
}
out := &ChatCompletionsResponse{
ID: id,
Object: "chat.completion",
Created: time.Now().Unix(),
Model: model,
}
var contentText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, part := range item.Content {
if part.Type == "output_text" && part.Text != "" {
contentText += part.Text
}
}
case "function_call":
toolCalls = append(toolCalls, ChatToolCall{
ID: item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: item.Name,
Arguments: item.Arguments,
},
})
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
}
}
case "web_search_call":
// silently consumed — results already incorporated into text output
}
}
msg := ChatMessage{Role: "assistant"}
if len(toolCalls) > 0 {
msg.ToolCalls = toolCalls
}
if contentText != "" {
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
out.Choices = []ChatChoice{{
Index: 0,
Message: msg,
FinishReason: finishReason,
}}
if resp.Usage != nil {
usage := &ChatUsage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
}
}
out.Usage = usage
}
return out
}
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
switch status {
case "incomplete":
if details != nil && details.Reason == "max_output_tokens" {
return "length"
}
return "stop"
case "completed":
if len(toolCalls) > 0 {
return "tool_calls"
}
return "stop"
default:
return "stop"
}
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
// ---------------------------------------------------------------------------
// ResponsesEventToChatState tracks state for converting a sequence of Responses
// SSE events into Chat Completions SSE chunks.
type ResponsesEventToChatState struct {
ID string
Model string
Created int64
SentRole bool
SawToolCall bool
SawText bool
Finalized bool // true after finish chunk has been emitted
NextToolCallIndex int // next sequential tool_call index to assign
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
IncludeUsage bool
Usage *ChatUsage
}
// NewResponsesEventToChatState returns an initialised stream state.
func NewResponsesEventToChatState() *ResponsesEventToChatState {
return &ResponsesEventToChatState{
ID: generateChatCmplID(),
Created: time.Now().Unix(),
OutputIndexToToolIndex: make(map[int]int),
}
}
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
// or more Chat Completions chunks, updating state as it goes.
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
switch evt.Type {
case "response.created":
return resToChatHandleCreated(evt, state)
case "response.output_text.delta":
return resToChatHandleTextDelta(evt, state)
case "response.output_item.added":
return resToChatHandleOutputItemAdded(evt, state)
case "response.function_call_arguments.delta":
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
return nil
}
}
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
// stream ended without a proper completion event (e.g. upstream disconnect).
// It is idempotent: if a completion event already emitted the finish chunk,
// this returns nil.
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
if state.Finalized {
return nil
}
state.Finalized = true
finishReason := "stop"
if state.SawToolCall {
finishReason = "tool_calls"
}
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
data, err := json.Marshal(chunk)
if err != nil {
return "", err
}
return fmt.Sprintf("data: %s\n\n", data), nil
}
// --- internal handlers ---
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Response != nil {
if evt.Response.ID != "" {
state.ID = evt.Response.ID
}
if state.Model == "" && evt.Response.Model != "" {
state.Model = evt.Response.Model
}
}
// Emit the role chunk.
if state.SentRole {
return nil
}
state.SentRole = true
role := "assistant"
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
}
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
state.SawText = true
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Item == nil || evt.Item.Type != "function_call" {
return nil
}
state.SawToolCall = true
idx := state.NextToolCallIndex
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
state.NextToolCallIndex++
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
ID: evt.Item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: evt.Item.Name,
},
}},
})}
}
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
if !ok {
return nil
}
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
Function: ChatFunctionCall{
Arguments: evt.Delta,
},
}},
})}
}
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
state.Finalized = true
finishReason := "stop"
if evt.Response != nil {
if evt.Response.Usage != nil {
u := evt.Response.Usage
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
state.Usage = usage
}
switch evt.Response.Status {
case "incomplete":
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
finishReason = "length"
}
case "completed":
if state.SawToolCall {
finishReason = "tool_calls"
}
}
} else if state.SawToolCall {
finishReason = "tool_calls"
}
var chunks []ChatCompletionsChunk
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: delta,
FinishReason: nil,
}},
}
}
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
empty := ""
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: ChatDelta{Content: &empty},
FinishReason: &finishReason,
}},
}
}
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
func generateChatCmplID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "chatcmpl-" + hex.EncodeToString(b)
}
......@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
SequenceNumber int `json:"sequence_number,omitempty"`
}
// ---------------------------------------------------------------------------
// OpenAI Chat Completions API types
// ---------------------------------------------------------------------------
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
// Legacy function calling (deprecated but still supported)
Functions []ChatFunction `json:"functions,omitempty"`
FunctionCall json.RawMessage `json:"function_call,omitempty"`
}
// ChatStreamOptions configures streaming behavior.
type ChatStreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
}
// ChatContentPart is a typed content part in a multi-modal message.
type ChatContentPart struct {
Type string `json:"type"` // "text" | "image_url"
Text string `json:"text,omitempty"`
ImageURL *ChatImageURL `json:"image_url,omitempty"`
}
// ChatImageURL contains the URL for an image content part.
type ChatImageURL struct {
URL string `json:"url"`
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
}
// ChatTool describes a tool available to the model.
type ChatTool struct {
Type string `json:"type"` // "function"
Function *ChatFunction `json:"function,omitempty"`
}
// ChatFunction describes a function tool definition.
type ChatFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// ChatToolCall represents a tool call made by the assistant.
// Index is only populated in streaming chunks (omitted in non-streaming responses).
type ChatToolCall struct {
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"` // "function"
Function ChatFunctionCall `json:"function"`
}
// ChatFunctionCall contains the function name and arguments.
type ChatFunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
type ChatCompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChoice is a single completion choice.
type ChatChoice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
}
// ChatUsage holds token counts in Chat Completions format.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
}
// ChatTokenDetails provides a breakdown of token usage.
type ChatTokenDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
}
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
type ChatCompletionsChunk struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion.chunk"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChunkChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChunkChoice is a single choice in a streaming chunk.
type ChatChunkChoice struct {
Index int `json:"index"`
Delta ChatDelta `json:"delta"`
FinishReason *string `json:"finish_reason"` // pointer: null when not final
}
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------
// Shared constants
// ---------------------------------------------------------------------------
......
package service
import (
"encoding/json"
"errors"
"strings"
"time"
)
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
if req == nil {
return nil, errors.New("request is nil")
}
model := strings.TrimSpace(getString(req["model"]))
if model == "" {
return nil, errors.New("model is required")
}
messagesRaw, ok := req["messages"]
if !ok {
return nil, errors.New("messages is required")
}
messages, ok := messagesRaw.([]any)
if !ok {
return nil, errors.New("messages must be an array")
}
input, err := convertChatMessagesToResponsesInput(messages)
if err != nil {
return nil, err
}
out := make(map[string]any, len(req)+1)
for key, value := range req {
switch key {
case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call":
continue
default:
out[key] = value
}
}
out["model"] = model
out["input"] = input
if _, ok := out["max_output_tokens"]; !ok {
if v, ok := req["max_tokens"]; ok {
out["max_output_tokens"] = v
} else if v, ok := req["max_completion_tokens"]; ok {
out["max_output_tokens"] = v
}
}
if _, ok := out["tools"]; !ok {
if functions, ok := req["functions"].([]any); ok && len(functions) > 0 {
tools := make([]any, 0, len(functions))
for _, fn := range functions {
if fnMap, ok := fn.(map[string]any); ok {
tools = append(tools, map[string]any{
"type": "function",
"function": fnMap,
})
}
}
if len(tools) > 0 {
out["tools"] = tools
}
}
}
if _, ok := out["tool_choice"]; !ok {
if functionCall, ok := req["function_call"]; ok {
out["tool_choice"] = functionCall
}
}
return out, nil
}
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
id := strings.TrimSpace(getString(resp["id"]))
if id == "" {
id = "chatcmpl-" + safeRandomHex(12)
}
model := strings.TrimSpace(getString(resp["model"]))
created := getInt64(resp["created_at"])
if created == 0 {
created = getInt64(resp["created"])
}
if created == 0 {
created = time.Now().Unix()
}
text, toolCalls := extractResponseTextAndToolCalls(resp)
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
message := map[string]any{
"role": "assistant",
"content": text,
}
if len(toolCalls) > 0 {
message["tool_calls"] = toolCalls
}
chatResp := map[string]any{
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"message": message,
"finish_reason": finishReason,
},
},
}
if usage := extractResponseUsage(resp); usage != nil {
chatResp["usage"] = usage
}
if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" {
chatResp["system_fingerprint"] = fingerprint
}
return json.Marshal(chatResp)
}
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
input := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
return nil, errors.New("message must be an object")
}
role := strings.TrimSpace(getString(msgMap["role"]))
if role == "" {
return nil, errors.New("message role is required")
}
switch role {
case "tool":
callID := strings.TrimSpace(getString(msgMap["tool_call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(msgMap["id"]))
}
output := extractMessageContentText(msgMap["content"])
input = append(input, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
case "function":
callID := strings.TrimSpace(getString(msgMap["name"]))
output := extractMessageContentText(msgMap["content"])
input = append(input, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
default:
convertedContent := convertChatContent(msgMap["content"])
toolCalls := []any(nil)
if role == "assistant" {
toolCalls = extractToolCallsFromMessage(msgMap)
}
skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent)
if !skipAssistantMessage {
msgItem := map[string]any{
"role": role,
"content": convertedContent,
}
if name := strings.TrimSpace(getString(msgMap["name"])); name != "" {
msgItem["name"] = name
}
input = append(input, msgItem)
}
if role == "assistant" && len(toolCalls) > 0 {
input = append(input, toolCalls...)
}
}
}
return input, nil
}
func convertChatContent(content any) any {
switch v := content.(type) {
case nil:
return ""
case string:
return v
case []any:
converted := make([]any, 0, len(v))
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
converted = append(converted, part)
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
switch partType {
case "text":
text := getString(partMap["text"])
if text != "" {
converted = append(converted, map[string]any{
"type": "input_text",
"text": text,
})
continue
}
case "image_url":
imageURL := ""
if imageObj, ok := partMap["image_url"].(map[string]any); ok {
imageURL = getString(imageObj["url"])
} else {
imageURL = getString(partMap["image_url"])
}
if imageURL != "" {
converted = append(converted, map[string]any{
"type": "input_image",
"image_url": imageURL,
})
continue
}
case "input_text", "input_image":
converted = append(converted, partMap)
continue
}
converted = append(converted, partMap)
}
return converted
default:
return v
}
}
func extractToolCallsFromMessage(msg map[string]any) []any {
var out []any
if toolCalls, ok := msg["tool_calls"].([]any); ok {
for _, call := range toolCalls {
callMap, ok := call.(map[string]any)
if !ok {
continue
}
callID := strings.TrimSpace(getString(callMap["id"]))
if callID == "" {
callID = strings.TrimSpace(getString(callMap["call_id"]))
}
name := ""
args := ""
if fn, ok := callMap["function"].(map[string]any); ok {
name = strings.TrimSpace(getString(fn["name"]))
args = getString(fn["arguments"])
}
if name == "" && args == "" {
continue
}
item := map[string]any{
"type": "tool_call",
}
if callID != "" {
item["call_id"] = callID
}
if name != "" {
item["name"] = name
}
if args != "" {
item["arguments"] = args
}
out = append(out, item)
}
}
if fnCall, ok := msg["function_call"].(map[string]any); ok {
name := strings.TrimSpace(getString(fnCall["name"]))
args := getString(fnCall["arguments"])
if name != "" || args != "" {
callID := strings.TrimSpace(getString(msg["tool_call_id"]))
if callID == "" {
callID = name
}
item := map[string]any{
"type": "function_call",
}
if callID != "" {
item["call_id"] = callID
}
if name != "" {
item["name"] = name
}
if args != "" {
item["arguments"] = args
}
out = append(out, item)
}
}
return out
}
func extractMessageContentText(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
parts := make([]string, 0, len(v))
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
if partType == "" || partType == "text" || partType == "output_text" || partType == "input_text" {
text := getString(partMap["text"])
if text != "" {
parts = append(parts, text)
}
}
}
return strings.Join(parts, "")
default:
return ""
}
}
func isEmptyContent(content any) bool {
switch v := content.(type) {
case nil:
return true
case string:
return strings.TrimSpace(v) == ""
case []any:
return len(v) == 0
default:
return false
}
}
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
output, ok := resp["output"].([]any)
if !ok {
if text, ok := resp["output_text"].(string); ok {
return text, nil
}
return "", nil
}
textParts := make([]string, 0)
toolCalls := make([]any, 0)
for _, item := range output {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType := strings.TrimSpace(getString(itemMap["type"]))
if itemType == "tool_call" || itemType == "function_call" {
if tc := responseItemToChatToolCall(itemMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
continue
}
content := itemMap["content"]
switch v := content.(type) {
case string:
if v != "" {
textParts = append(textParts, v)
}
case []any:
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
switch partType {
case "output_text", "text", "input_text":
text := getString(partMap["text"])
if text != "" {
textParts = append(textParts, text)
}
case "tool_call", "function_call":
if tc := responseItemToChatToolCall(partMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
}
}
}
}
return strings.Join(textParts, ""), toolCalls
}
func responseItemToChatToolCall(item map[string]any) map[string]any {
callID := strings.TrimSpace(getString(item["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(item["id"]))
}
name := strings.TrimSpace(getString(item["name"]))
arguments := getString(item["arguments"])
if fn, ok := item["function"].(map[string]any); ok {
if name == "" {
name = strings.TrimSpace(getString(fn["name"]))
}
if arguments == "" {
arguments = getString(fn["arguments"])
}
}
if name == "" && arguments == "" && callID == "" {
return nil
}
if callID == "" {
callID = "call_" + safeRandomHex(6)
}
return map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": arguments,
},
}
}
func extractResponseUsage(resp map[string]any) map[string]any {
usage, ok := resp["usage"].(map[string]any)
if !ok {
return nil
}
promptTokens := int(getNumber(usage["input_tokens"]))
completionTokens := int(getNumber(usage["output_tokens"]))
if promptTokens == 0 && completionTokens == 0 {
return nil
}
return map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
}
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
func safeRandomHex(byteLength int) string {
value, err := randomHexString(byteLength)
if err != nil || value == "" {
return "000000"
}
return value
}
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
type chatStreamingResult struct {
usage *OpenAIUsage
firstTokenMs *int
}
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
originalModel := reqModel
bodyModified := false
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
reqBody["model"] = mappedModel
bodyModified = true
}
if reqStream && includeUsage {
streamOptions, _ := reqBody["stream_options"].(map[string]any)
if streamOptions == nil {
streamOptions = map[string]any{}
}
if _, ok := streamOptions["include_usage"]; !ok {
streamOptions["include_usage"] = true
reqBody["stream_options"] = streamOptions
bodyModified = true
}
}
if bodyModified {
var err error
body, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("serialize request body: %w", err)
}
}
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
if err != nil {
return nil, err
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
if err != nil {
return nil, err
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
var targetURL string
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
targetURL = openaiChatAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/chat/completions"
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("authorization", "Bearer "+token)
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiChatAllowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
req.Header.Set("user-agent", customUA)
}
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
return req, nil
}
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for {
select {
case ev, ok := <-events:
if !ok {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
lastDataAt = time.Now()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
if firstTokenMs == nil {
if event := parseChatStreamEvent(data); event != nil {
if chatChunkHasDelta(event) {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
applyChatUsageFromEvent(event, usage)
}
} else {
if event := parseChatStreamEvent(data); event != nil {
applyChatUsageFromEvent(event, usage)
}
}
} else {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
}
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
usage := &OpenAIUsage{}
var parsed map[string]any
if json.Unmarshal(body, &parsed) == nil {
if usageMap, ok := parsed["usage"].(map[string]any); ok {
applyChatUsageFromMap(usageMap, usage)
}
}
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
body = s.correctToolCallsInResponseBody(body)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func parseChatStreamEvent(data string) map[string]any {
if data == "" || data == "[DONE]" {
return nil
}
var event map[string]any
if json.Unmarshal([]byte(data), &event) != nil {
return nil
}
return event
}
func chatChunkHasDelta(event map[string]any) bool {
choices, ok := event["choices"].([]any)
if !ok {
return false
}
for _, choice := range choices {
choiceMap, ok := choice.(map[string]any)
if !ok {
continue
}
delta, ok := choiceMap["delta"].(map[string]any)
if !ok {
continue
}
if content, ok := delta["content"].(string); ok && strings.TrimSpace(content) != "" {
return true
}
if toolCalls, ok := delta["tool_calls"].([]any); ok && len(toolCalls) > 0 {
return true
}
if functionCall, ok := delta["function_call"].(map[string]any); ok && len(functionCall) > 0 {
return true
}
}
return false
}
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
if event == nil || usage == nil {
return
}
usageMap, ok := event["usage"].(map[string]any)
if !ok {
return
}
applyChatUsageFromMap(usageMap, usage)
}
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
if usageMap == nil || usage == nil {
return
}
promptTokens := int(getNumber(usageMap["prompt_tokens"]))
completionTokens := int(getNumber(usageMap["completion_tokens"]))
if promptTokens > 0 {
usage.InputTokens = promptTokens
}
if completionTokens > 0 {
usage.OutputTokens = completionTokens
}
}
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestConvertChatCompletionsToResponses(t *testing.T) {
req := map[string]any{
"model": "gpt-4o",
"messages": []any{
map[string]any{
"role": "user",
"content": "hello",
},
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "ping",
"arguments": "{}",
},
},
},
},
map[string]any{
"role": "tool",
"tool_call_id": "call_1",
"content": "ok",
"response": "ignored",
"response_time": 1,
},
},
"functions": []any{
map[string]any{
"name": "ping",
"description": "ping tool",
"parameters": map[string]any{"type": "object"},
},
},
"function_call": map[string]any{"name": "ping"},
}
converted, err := ConvertChatCompletionsToResponses(req)
require.NoError(t, err)
require.Equal(t, "gpt-4o", converted["model"])
input, ok := converted["input"].([]any)
require.True(t, ok)
require.Len(t, input, 3)
toolCall := findInputItemByType(input, "tool_call")
require.NotNil(t, toolCall)
require.Equal(t, "call_1", toolCall["call_id"])
toolOutput := findInputItemByType(input, "function_call_output")
require.NotNil(t, toolOutput)
require.Equal(t, "call_1", toolOutput["call_id"])
tools, ok := converted["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, 1)
require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
}
func TestConvertResponsesToChatCompletion(t *testing.T) {
resp := map[string]any{
"id": "resp_123",
"model": "gpt-4o",
"created_at": 1700000000,
"output": []any{
map[string]any{
"type": "message",
"role": "assistant",
"content": []any{
map[string]any{
"type": "output_text",
"text": "hi",
},
},
},
},
"usage": map[string]any{
"input_tokens": 2,
"output_tokens": 3,
},
}
body, err := json.Marshal(resp)
require.NoError(t, err)
converted, err := ConvertResponsesToChatCompletion(body)
require.NoError(t, err)
var chat map[string]any
require.NoError(t, json.Unmarshal(converted, &chat))
require.Equal(t, "chat.completion", chat["object"])
choices, ok := chat["choices"].([]any)
require.True(t, ok)
require.Len(t, choices, 1)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
message, ok := choice["message"].(map[string]any)
require.True(t, ok)
require.Equal(t, "hi", message["content"])
usage, ok := chat["usage"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), usage["prompt_tokens"])
require.Equal(t, float64(3), usage["completion_tokens"])
require.Equal(t, float64(5), usage["total_tokens"])
}
func findInputItemByType(items []any, itemType string) map[string]any {
for _, item := range items {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
if itemMap["type"] == itemType {
return itemMap
}
}
return nil
}
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var chatReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := chatReq.Model
clientStream := chatReq.Stream
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
// 2. Convert to Responses and forward
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", clientStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 9. Handle normal response
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(chunks) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
}
// Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
......@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
......@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
......
......@@ -12,7 +12,6 @@ import (
"io"
"math/rand"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
......@@ -37,7 +36,6 @@ const (
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiChatAPIURL = "https://api.openai.com/v1/chat/completions"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.104.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
......@@ -56,16 +54,6 @@ const (
codexCLIVersion = "0.104.0"
)
// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
......@@ -109,19 +97,6 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
"X-Real-IP",
}
// OpenAI chat-completions allowed headers (extend responses whitelist).
var openaiChatAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
"openai-organization": true,
"openai-project": true,
"openai-beta": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
......@@ -1602,23 +1577,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
}
if c != nil && account != nil && account.Type == AccountTypeAPIKey {
if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
includeUsage := false
if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
if flag, ok := v.(bool); ok {
includeUsage = flag
}
}
if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
passthroughWriter.SetPassthrough()
}
return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime)
}
}
}
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
......@@ -2989,6 +2947,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages).
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes — if the account does not handle this status,
// return a generic error without exposing upstream details.
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Map status code to error type and write response
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment