Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
8dd38f47
Unverified
Commit
8dd38f47
authored
Mar 11, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 11, 2026
Browse files
Merge pull request #926 from 7976723/feat/chat-completions-compat-v2
feat: 添加 OpenAI Chat Completions 兼容端点(基于 #648,修复编译错误和运行时 panic)
parents
6bbe7800
a17ac501
Changes
6
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/openai_chat_completions.go
0 → 100644
View file @
8dd38f47
package
handler
import
(
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// POST /v1/chat/completions
//
// It reads the chat-completions request, converts it to a Responses-API
// request, swaps in a translating response writer, and delegates to the
// existing Responses handler. The writer converts the upstream reply back
// to Chat Completions format (streaming or buffered) before it reaches the
// client.
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		// Distinguish "body over the configured size limit" from generic read errors.
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	// Preserve original chat-completions request for upstream passthrough when needed.
	c.Set(service.OpenAIChatCompletionsBodyKey, body)
	var chatReq map[string]any
	if err := json.Unmarshal(body, &chatReq); err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}
	// stream_options.include_usage controls whether the final stream chunk carries usage.
	includeUsage := false
	if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
		if v, ok := streamOptions["include_usage"].(bool); ok {
			includeUsage = v
		}
	}
	c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
	converted, err := service.ConvertChatCompletionsToResponses(chatReq)
	if err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
		return
	}
	convertedBody, err := json.Marshal(converted)
	if err != nil {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
		return
	}
	stream, _ := converted["stream"].(bool)
	model, _ := converted["model"].(string)
	// Wrap the writer so the Responses handler's output is translated on the way out.
	originalWriter := c.Writer
	writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
	c.Writer = writer
	// Replace the request body with the converted Responses-API payload.
	c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
	c.Request.ContentLength = int64(len(convertedBody))
	h.Responses(c)
	// Flush/convert any buffered (non-stream) body, then restore the writer.
	writer.Finalize()
	c.Writer = originalWriter
}
// chatCompletionsResponseWriter wraps gin.ResponseWriter and translates an
// upstream Responses-API reply into OpenAI Chat Completions wire format on
// the way out — line-by-line for SSE streams, or as one buffered JSON body
// for non-streaming responses.
type chatCompletionsResponseWriter struct {
	gin.ResponseWriter
	stream       bool         // client requested SSE streaming
	includeUsage bool         // attach usage to the final stream chunk
	buffer       bytes.Buffer // accumulates the whole body in non-stream mode
	streamBuf    bytes.Buffer // holds partial SSE data until a full line arrives
	state        *chatCompletionStreamState
	corrector    *service.CodexToolCorrector
	finalized    bool // Finalize has run; later non-stream writes are dropped
	passthrough  bool // forward bytes untouched, no conversion
}
// chatCompletionStreamState accumulates per-response metadata while
// converting a Responses-API SSE stream into Chat Completions chunks.
type chatCompletionStreamState struct {
	id            string // chat completion id (from upstream, or synthesized)
	model         string // model name echoed into every chunk
	created       int64  // unix timestamp echoed into every chunk
	sentRole      bool   // the "assistant" role delta has been emitted once
	sawToolCall   bool   // at least one tool call seen → finish_reason "tool_calls"
	sawText       bool   // text deltas streamed → suppress duplicate "done" text
	toolCallIndex map[string]int // stable per-call_id index for tool_calls deltas
	usage         map[string]any // usage payload captured from response.completed
}
func
newChatCompletionsResponseWriter
(
w
gin
.
ResponseWriter
,
stream
bool
,
includeUsage
bool
,
model
string
)
*
chatCompletionsResponseWriter
{
return
&
chatCompletionsResponseWriter
{
ResponseWriter
:
w
,
stream
:
stream
,
includeUsage
:
includeUsage
,
state
:
&
chatCompletionStreamState
{
model
:
strings
.
TrimSpace
(
model
),
toolCallIndex
:
make
(
map
[
string
]
int
),
},
corrector
:
service
.
NewCodexToolCorrector
(),
}
}
// Write routes response bytes according to the writer's mode: passthrough
// forwards untouched; streaming mode buffers until complete SSE lines can be
// converted; non-stream mode accumulates the body for Finalize. After
// Finalize, non-stream writes are discarded but still acknowledged so
// callers don't treat the drop as an error.
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
	if w.passthrough {
		return w.ResponseWriter.Write(data)
	}
	if w.stream {
		n, err := w.streamBuf.Write(data)
		if err != nil {
			return n, err
		}
		// Convert and emit every complete line received so far.
		w.flushStreamBuffer()
		return n, nil
	}
	if w.finalized {
		// Late write after finalization: drop silently, report full length.
		return len(data), nil
	}
	return w.buffer.Write(data)
}
func
(
w
*
chatCompletionsResponseWriter
)
WriteString
(
s
string
)
(
int
,
error
)
{
return
w
.
Write
([]
byte
(
s
))
}
// Finalize completes the response after the wrapped handler returns. In
// non-stream mode it converts the buffered Responses body into Chat
// Completions JSON and writes it out, falling back to the raw body when
// conversion fails (e.g. an upstream error payload). Streaming and
// passthrough modes need no post-processing. Safe to call more than once.
func (w *chatCompletionsResponseWriter) Finalize() {
	if w.finalized {
		return
	}
	w.finalized = true
	if w.passthrough {
		return
	}
	if w.stream {
		return
	}
	body := w.buffer.Bytes()
	if len(body) == 0 {
		return
	}
	// The converted body differs in size from the upstream one.
	w.ResponseWriter.Header().Del("Content-Length")
	converted, err := service.ConvertResponsesToChatCompletion(body)
	if err != nil {
		// Conversion failed: forward the upstream body verbatim.
		_, _ = w.ResponseWriter.Write(body)
		return
	}
	corrected := converted
	// Give the tool-call corrector a chance to repair malformed tool calls.
	if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
		corrected = []byte(correctedStr)
	}
	_, _ = w.ResponseWriter.Write(corrected)
}
// SetPassthrough switches the writer into passthrough mode: all subsequent
// writes are forwarded to the underlying writer without conversion.
func (w *chatCompletionsResponseWriter) SetPassthrough() {
	w.passthrough = true
}
func
(
w
*
chatCompletionsResponseWriter
)
Status
()
int
{
if
w
.
ResponseWriter
==
nil
{
return
0
}
return
w
.
ResponseWriter
.
Status
()
}
func
(
w
*
chatCompletionsResponseWriter
)
Written
()
bool
{
if
w
.
ResponseWriter
==
nil
{
return
false
}
return
w
.
ResponseWriter
.
Written
()
}
// flushStreamBuffer drains every complete line ("\n"-terminated) currently
// held in streamBuf, stripping trailing CR/LF before handing each line to
// handleStreamLine. A trailing partial line stays buffered for the next Write.
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
	for {
		buf := w.streamBuf.Bytes()
		idx := bytes.IndexByte(buf, '\n')
		if idx == -1 {
			// No complete line yet; wait for more data.
			return
		}
		// Next consumes the line (including its '\n') from the buffer.
		lineBytes := w.streamBuf.Next(idx + 1)
		line := strings.TrimRight(string(lineBytes), "\r\n")
		w.handleStreamLine(line)
	}
}
// handleStreamLine processes one SSE line from the upstream Responses stream.
// Blank lines are dropped, comment lines (":...") are forwarded as-is, and
// "data:" payloads are converted to zero or more Chat Completions chunks,
// each re-framed as its own SSE event. Other field lines (event:, id:) are
// discarded since the converted chunks carry their own framing.
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
	if line == "" {
		return
	}
	if strings.HasPrefix(line, ":") {
		// SSE comment / keep-alive: pass through with event-terminating blank line.
		_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
		return
	}
	if !strings.HasPrefix(line, "data:") {
		return
	}
	data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	for _, chunk := range w.convertResponseDataToChatChunks(data) {
		if chunk == "" {
			continue
		}
		if chunk == "[DONE]" {
			_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			continue
		}
		_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
	}
}
// convertResponseDataToChatChunks maps one Responses-API SSE data payload to
// zero or more Chat Completions chunk payloads (JSON strings without the
// "data: " prefix). "[DONE]" propagates as-is; non-JSON data, error events,
// and untyped payloads are forwarded verbatim so failures reach the client.
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
	if data == "" {
		return nil
	}
	if data == "[DONE]" {
		return []string{"[DONE]"}
	}
	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not JSON: pass through untouched.
		return []string{data}
	}
	if _, ok := payload["error"]; ok {
		// Upstream error events are forwarded verbatim.
		return []string{data}
	}
	eventType := strings.TrimSpace(getString(payload["type"]))
	if eventType == "" {
		return []string{data}
	}
	// Capture response id/model/created as soon as any event carries them.
	w.state.applyMetadata(payload)
	switch eventType {
	case "response.created":
		return nil
	case "response.output_text.delta":
		delta := getString(payload["delta"])
		if delta == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(delta)}
	case "response.output_text.done":
		if w.state.sawText {
			// Deltas already streamed this text; don't repeat it.
			return nil
		}
		text := getString(payload["text"])
		if text == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(text)}
	case "response.output_item.added", "response.output_item.delta":
		if item, ok := payload["item"].(map[string]any); ok {
			if callID, name, args, ok := extractToolCallFromItem(item); ok {
				w.state.sawToolCall = true
				return []string{w.buildToolCallChunk(callID, name, args)}
			}
		}
	case "response.completed", "response.done":
		if responseObj, ok := payload["response"].(map[string]any); ok {
			w.state.applyResponseUsage(responseObj)
		}
		return []string{w.buildFinalChunk()}
	}
	// Fallback for tool/function-call event variants not matched above.
	if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
		callID := strings.TrimSpace(getString(payload["call_id"]))
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["tool_call_id"]))
		}
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["id"]))
		}
		args := getString(payload["delta"])
		name := strings.TrimSpace(getString(payload["name"]))
		if callID != "" && (args != "" || name != "") {
			w.state.sawToolCall = true
			return []string{w.buildToolCallChunk(callID, name, args)}
		}
	}
	return nil
}
// buildTextDeltaChunk builds a Chat Completions chunk carrying a content
// delta, attaching the "assistant" role on the first delta of the stream.
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
	w.state.ensureDefaults()
	payload := map[string]any{
		"content": delta,
	}
	if !w.state.sentRole {
		// The role is sent exactly once, on the first delta.
		payload["role"] = "assistant"
		w.state.sentRole = true
	}
	return w.buildChunk(payload, nil, nil)
}
// buildToolCallChunk builds a Chat Completions chunk carrying a tool_calls
// delta for callID. The tool call keeps a stable index per call_id so
// argument fragments for the same call accumulate under one entry; name and
// args are only included when non-empty.
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
	w.state.ensureDefaults()
	index := w.state.toolCallIndexFor(callID)
	function := map[string]any{}
	if name != "" {
		function["name"] = name
	}
	if args != "" {
		function["arguments"] = args
	}
	toolCall := map[string]any{
		"index":    index,
		"id":       callID,
		"type":     "function",
		"function": function,
	}
	delta := map[string]any{
		"tool_calls": []any{toolCall},
	}
	if !w.state.sentRole {
		// The role is sent exactly once, on the first delta of the stream.
		delta["role"] = "assistant"
		w.state.sentRole = true
	}
	return w.buildChunk(delta, nil, nil)
}
// buildFinalChunk builds the terminating chunk: empty delta, finish_reason
// "tool_calls" when any tool call was streamed (else "stop"), and usage when
// the client opted in via stream_options.include_usage and upstream reported it.
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
	w.state.ensureDefaults()
	finishReason := "stop"
	if w.state.sawToolCall {
		finishReason = "tool_calls"
	}
	usage := map[string]any(nil)
	if w.includeUsage && w.state.usage != nil {
		usage = w.state.usage
	}
	return w.buildChunk(map[string]any{}, finishReason, usage)
}
// buildChunk assembles one chat.completion.chunk JSON string from the given
// delta, finish_reason (nil mid-stream), and optional usage, then lets the
// tool-call corrector rewrite the serialized payload if it detects a
// repairable tool call.
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
	w.state.ensureDefaults()
	chunk := map[string]any{
		"id":      w.state.id,
		"object":  "chat.completion.chunk",
		"created": w.state.created,
		"model":   w.state.model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			},
		},
	}
	if usage != nil {
		chunk["usage"] = usage
	}
	// Marshal of map[string]any with JSON-safe values cannot fail; error ignored.
	data, _ := json.Marshal(chunk)
	if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
		return corrected
	}
	return string(data)
}
// ensureDefaults backfills id, model, and created so every emitted chunk is
// well-formed even before upstream metadata has been observed.
func (s *chatCompletionStreamState) ensureDefaults() {
	if s.id == "" {
		// Synthesize an id in the conventional "chatcmpl-" form.
		s.id = "chatcmpl-" + randomHexUnsafe(12)
	}
	if s.model == "" {
		s.model = "unknown"
	}
	if s.created == 0 {
		s.created = time.Now().Unix()
	}
}
func
(
s
*
chatCompletionStreamState
)
toolCallIndexFor
(
callID
string
)
int
{
if
idx
,
ok
:=
s
.
toolCallIndex
[
callID
];
ok
{
return
idx
}
idx
:=
len
(
s
.
toolCallIndex
)
s
.
toolCallIndex
[
callID
]
=
idx
return
idx
}
// applyMetadata fills any still-unset id/model/created fields from an SSE
// event payload. A nested "response" object takes precedence (applied first);
// top-level "response_id" wins over "id", and "created_at" over "created".
// Already-populated fields are never overwritten.
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
	if responseObj, ok := payload["response"].(map[string]any); ok {
		s.applyResponseMetadata(responseObj)
	}
	if s.id == "" {
		if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
			s.id = id
		} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
			s.id = id
		}
	}
	if s.model == "" {
		if model := strings.TrimSpace(getString(payload["model"])); model != "" {
			s.model = model
		}
	}
	if s.created == 0 {
		if created := getInt64(payload["created_at"]); created != 0 {
			s.created = created
		} else if created := getInt64(payload["created"]); created != 0 {
			s.created = created
		}
	}
}
// applyResponseMetadata fills any still-unset id/model/created fields from a
// nested "response" object; populated fields are never overwritten.
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
	if s.id == "" {
		if id := strings.TrimSpace(getString(responseObj["id"])); id != "" {
			s.id = id
		}
	}
	if s.model == "" {
		if model := strings.TrimSpace(getString(responseObj["model"])); model != "" {
			s.model = model
		}
	}
	if s.created == 0 {
		if created := getInt64(responseObj["created_at"]); created != 0 {
			s.created = created
		}
	}
}
// applyResponseUsage records Chat Completions-shaped usage from the nested
// "response.usage" object (input_tokens/output_tokens → prompt/completion/
// total). All-zero usage is treated as "not reported" and skipped.
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
	usage, ok := responseObj["usage"].(map[string]any)
	if !ok {
		return
	}
	promptTokens := int(getNumber(usage["input_tokens"]))
	completionTokens := int(getNumber(usage["output_tokens"]))
	if promptTokens == 0 && completionTokens == 0 {
		return
	}
	s.usage = map[string]any{
		"prompt_tokens":     promptTokens,
		"completion_tokens": completionTokens,
		"total_tokens":      promptTokens + completionTokens,
	}
}
// extractToolCallFromItem pulls (callID, name, arguments, ok) out of a
// Responses output item of type "tool_call" or "function_call". The call id
// prefers "call_id" over "id"; name/arguments fall back to a nested
// "function" object. Returns ok=false when the item type doesn't match or
// the item carries no identifying data; a missing call id is synthesized.
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
	itemType := strings.TrimSpace(getString(item["type"]))
	if itemType != "tool_call" && itemType != "function_call" {
		return "", "", "", false
	}
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	args := getString(item["arguments"])
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if args == "" {
			args = getString(fn["arguments"])
		}
	}
	if callID == "" && name == "" && args == "" {
		return "", "", "", false
	}
	if callID == "" {
		// Synthesize an id so downstream indexing by call_id still works.
		callID = "call_" + randomHexUnsafe(6)
	}
	return callID, name, args, true
}
// getString coerces a decoded-JSON value to a string: strings pass through,
// byte slices are converted, json.Number uses its literal text, and every
// other type (including nil) yields "".
func getString(value any) string {
	if s, ok := value.(string); ok {
		return s
	}
	if b, ok := value.([]byte); ok {
		return string(b)
	}
	if n, ok := value.(json.Number); ok {
		return n.String()
	}
	return ""
}
// getNumber coerces a decoded-JSON value to float64. It accepts the numeric
// types a decoder may produce (float64/float32/int/int64/json.Number) and
// returns 0 for everything else; a malformed json.Number also yields 0.
func getNumber(value any) float64 {
	switch n := value.(type) {
	case float64:
		return n
	case float32:
		return float64(n)
	case int:
		return float64(n)
	case int64:
		return float64(n)
	case json.Number:
		parsed, _ := n.Float64()
		return parsed
	}
	return 0
}
// getInt64 coerces a decoded-JSON value to int64. float64 values are
// truncated toward zero; a malformed json.Number and all non-numeric types
// yield 0.
func getInt64(value any) int64 {
	switch n := value.(type) {
	case int64:
		return n
	case int:
		return int64(n)
	case float64:
		return int64(n)
	case json.Number:
		parsed, _ := n.Int64()
		return parsed
	}
	return 0
}
// randomHexUnsafe returns byteLength random bytes hex-encoded (2*byteLength
// characters); a non-positive byteLength is coerced to 8. If crypto/rand
// fails (extremely rare), it falls back to an all-zero string of the SAME
// width — the previous fixed "000000" fallback silently changed the
// identifier length callers expect.
func randomHexUnsafe(byteLength int) string {
	if byteLength <= 0 {
		byteLength = 8
	}
	buf := make([]byte, byteLength)
	if _, err := rand.Read(buf); err != nil {
		return strings.Repeat("0", 2*byteLength)
	}
	return hex.EncodeToString(buf)
}
backend/internal/server/routes/gateway.go
View file @
8dd38f47
...
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
...
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
POST
(
"/responses/*subpath"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
POST
(
"/responses/*subpath"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
GET
(
"/responses"
,
h
.
OpenAIGateway
.
ResponsesWebSocket
)
gateway
.
GET
(
"/responses"
,
h
.
OpenAIGateway
.
ResponsesWebSocket
)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
// OpenAI Chat Completions API
gateway
.
POST
(
"/chat/completions"
,
func
(
c
*
gin
.
Context
)
{
gateway
.
POST
(
"/chat/completions"
,
h
.
OpenAIGateway
.
ChatCompletions
)
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"invalid_request_error"
,
"message"
:
"Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses."
,
},
})
})
}
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
...
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
r
.
POST
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
Responses
)
r
.
POST
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
Responses
)
r
.
POST
(
"/responses/*subpath"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
Responses
)
r
.
POST
(
"/responses/*subpath"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
Responses
)
r
.
GET
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
ResponsesWebSocket
)
r
.
GET
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
ResponsesWebSocket
)
// OpenAI Chat Completions API(不带v1前缀的别名)
r
.
POST
(
"/chat/completions"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
OpenAIGateway
.
ChatCompletions
)
// Antigravity 模型列表
// Antigravity 模型列表
r
.
GET
(
"/antigravity/models"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
Gateway
.
AntigravityModels
)
r
.
GET
(
"/antigravity/models"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
requireGroupAnthropic
,
h
.
Gateway
.
AntigravityModels
)
...
...
backend/internal/service/openai_chat_completions.go
0 → 100644
View file @
8dd38f47
package
service
import
(
"encoding/json"
"errors"
"strings"
"time"
)
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
//
// "messages" becomes "input"; legacy keys (max_tokens/max_completion_tokens,
// functions/function_call, stream_options) are dropped from the copy and
// re-expressed as max_output_tokens, tools, and tool_choice — but only when
// the request did not already set the Responses-style key. All other keys
// pass through unchanged. Returns an error for a nil request, missing model,
// or malformed messages.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
	if req == nil {
		return nil, errors.New("request is nil")
	}
	model := strings.TrimSpace(getString(req["model"]))
	if model == "" {
		return nil, errors.New("model is required")
	}
	messagesRaw, ok := req["messages"]
	if !ok {
		return nil, errors.New("messages is required")
	}
	messages, ok := messagesRaw.([]any)
	if !ok {
		return nil, errors.New("messages must be an array")
	}
	input, err := convertChatMessagesToResponsesInput(messages)
	if err != nil {
		return nil, err
	}
	// Copy everything except the chat-completions-only keys handled below.
	out := make(map[string]any, len(req)+1)
	for key, value := range req {
		switch key {
		case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call":
			continue
		default:
			out[key] = value
		}
	}
	out["model"] = model
	out["input"] = input
	// max_tokens (preferred) / max_completion_tokens → max_output_tokens.
	if _, ok := out["max_output_tokens"]; !ok {
		if v, ok := req["max_tokens"]; ok {
			out["max_output_tokens"] = v
		} else if v, ok := req["max_completion_tokens"]; ok {
			out["max_output_tokens"] = v
		}
	}
	// Legacy "functions" → Responses "tools" wrappers.
	if _, ok := out["tools"]; !ok {
		if functions, ok := req["functions"].([]any); ok && len(functions) > 0 {
			tools := make([]any, 0, len(functions))
			for _, fn := range functions {
				if fnMap, ok := fn.(map[string]any); ok {
					tools = append(tools, map[string]any{
						"type":     "function",
						"function": fnMap,
					})
				}
			}
			if len(tools) > 0 {
				out["tools"] = tools
			}
		}
	}
	// Legacy "function_call" → "tool_choice".
	if _, ok := out["tool_choice"]; !ok {
		if functionCall, ok := req["function_call"]; ok {
			out["tool_choice"] = functionCall
		}
	}
	return out, nil
}
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
//
// Missing id/created values are synthesized; assistant text and tool calls
// are extracted from the output array; finish_reason becomes "tool_calls"
// when any tool call is present, else "stop". Usage and system_fingerprint
// are carried over when reported. Returns an error only when body is not
// valid JSON.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
	var resp map[string]any
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	id := strings.TrimSpace(getString(resp["id"]))
	if id == "" {
		id = "chatcmpl-" + safeRandomHex(12)
	}
	model := strings.TrimSpace(getString(resp["model"]))
	// Prefer created_at, then created, then "now".
	created := getInt64(resp["created_at"])
	if created == 0 {
		created = getInt64(resp["created"])
	}
	if created == 0 {
		created = time.Now().Unix()
	}
	text, toolCalls := extractResponseTextAndToolCalls(resp)
	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	message := map[string]any{
		"role":    "assistant",
		"content": text,
	}
	if len(toolCalls) > 0 {
		message["tool_calls"] = toolCalls
	}
	chatResp := map[string]any{
		"id":      id,
		"object":  "chat.completion",
		"created": created,
		"model":   model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"message":       message,
				"finish_reason": finishReason,
			},
		},
	}
	if usage := extractResponseUsage(resp); usage != nil {
		chatResp["usage"] = usage
	}
	if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" {
		chatResp["system_fingerprint"] = fingerprint
	}
	return json.Marshal(chatResp)
}
// convertChatMessagesToResponsesInput maps Chat Completions messages to
// Responses input items. "tool" and legacy "function" messages become
// function_call_output items; other roles become message items (content
// converted via convertChatContent). An assistant message whose only payload
// is tool calls is skipped as a message and emitted as tool-call items
// instead; when it has both, the message item precedes the tool-call items.
// Fails on non-object messages or a missing role.
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
	input := make([]any, 0, len(messages))
	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			return nil, errors.New("message must be an object")
		}
		role := strings.TrimSpace(getString(msgMap["role"]))
		if role == "" {
			return nil, errors.New("message role is required")
		}
		switch role {
		case "tool":
			callID := strings.TrimSpace(getString(msgMap["tool_call_id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(msgMap["id"]))
			}
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		case "function":
			// Legacy function role: the function name doubles as the call id.
			callID := strings.TrimSpace(getString(msgMap["name"]))
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		default:
			convertedContent := convertChatContent(msgMap["content"])
			toolCalls := []any(nil)
			if role == "assistant" {
				toolCalls = extractToolCallsFromMessage(msgMap)
			}
			// An assistant turn with tool calls and no content contributes only the calls.
			skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent)
			if !skipAssistantMessage {
				msgItem := map[string]any{
					"role":    role,
					"content": convertedContent,
				}
				if name := strings.TrimSpace(getString(msgMap["name"])); name != "" {
					msgItem["name"] = name
				}
				input = append(input, msgItem)
			}
			if role == "assistant" && len(toolCalls) > 0 {
				input = append(input, toolCalls...)
			}
		}
	}
	return input, nil
}
// convertChatContent maps a Chat Completions message content value to
// Responses form: nil → "", strings pass through, and content-part arrays
// have "text" parts rewritten to "input_text" and "image_url" parts (object
// or bare-string form) to "input_image". Unrecognized parts — and typed
// parts whose payload is empty — fall through and are kept as-is so no data
// is dropped. Non-string, non-array content is returned unchanged.
func convertChatContent(content any) any {
	switch v := content.(type) {
	case nil:
		return ""
	case string:
		return v
	case []any:
		converted := make([]any, 0, len(v))
		for _, part := range v {
			partMap, ok := part.(map[string]any)
			if !ok {
				// Non-object part: keep verbatim.
				converted = append(converted, part)
				continue
			}
			partType := strings.TrimSpace(getString(partMap["type"]))
			switch partType {
			case "text":
				text := getString(partMap["text"])
				if text != "" {
					converted = append(converted, map[string]any{
						"type": "input_text",
						"text": text,
					})
					continue
				}
				// Empty text: fall through to the verbatim append below.
			case "image_url":
				imageURL := ""
				// OpenAI uses {"image_url":{"url":...}}; also accept a bare string.
				if imageObj, ok := partMap["image_url"].(map[string]any); ok {
					imageURL = getString(imageObj["url"])
				} else {
					imageURL = getString(partMap["image_url"])
				}
				if imageURL != "" {
					converted = append(converted, map[string]any{
						"type":      "input_image",
						"image_url": imageURL,
					})
					continue
				}
			case "input_text", "input_image":
				// Already in Responses form.
				converted = append(converted, partMap)
				continue
			}
			converted = append(converted, partMap)
		}
		return converted
	default:
		return v
	}
}
// extractToolCallsFromMessage converts an assistant message's "tool_calls"
// array and legacy "function_call" object into Responses input items.
// Entries with neither name nor arguments are dropped; call_id/name/
// arguments keys are only set when non-empty.
//
// NOTE(review): tool_calls entries are emitted with type "tool_call" while
// the legacy function_call path emits "function_call" — confirm the upstream
// Responses API accepts both item types.
func extractToolCallsFromMessage(msg map[string]any) []any {
	var out []any
	if toolCalls, ok := msg["tool_calls"].([]any); ok {
		for _, call := range toolCalls {
			callMap, ok := call.(map[string]any)
			if !ok {
				continue
			}
			callID := strings.TrimSpace(getString(callMap["id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(callMap["call_id"]))
			}
			name := ""
			args := ""
			if fn, ok := callMap["function"].(map[string]any); ok {
				name = strings.TrimSpace(getString(fn["name"]))
				args = getString(fn["arguments"])
			}
			if name == "" && args == "" {
				continue
			}
			item := map[string]any{
				"type": "tool_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}
	if fnCall, ok := msg["function_call"].(map[string]any); ok {
		name := strings.TrimSpace(getString(fnCall["name"]))
		args := getString(fnCall["arguments"])
		if name != "" || args != "" {
			callID := strings.TrimSpace(getString(msg["tool_call_id"]))
			if callID == "" {
				// Legacy function_call has no id; reuse the name.
				callID = name
			}
			item := map[string]any{
				"type": "function_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}
	return out
}
func
extractMessageContentText
(
content
any
)
string
{
switch
v
:=
content
.
(
type
)
{
case
string
:
return
v
case
[]
any
:
parts
:=
make
([]
string
,
0
,
len
(
v
))
for
_
,
part
:=
range
v
{
partMap
,
ok
:=
part
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
partType
:=
strings
.
TrimSpace
(
getString
(
partMap
[
"type"
]))
if
partType
==
""
||
partType
==
"text"
||
partType
==
"output_text"
||
partType
==
"input_text"
{
text
:=
getString
(
partMap
[
"text"
])
if
text
!=
""
{
parts
=
append
(
parts
,
text
)
}
}
}
return
strings
.
Join
(
parts
,
""
)
default
:
return
""
}
}
// isEmptyContent reports whether a message content value carries no payload:
// nil, a blank or whitespace-only string, or an empty array. Every other
// type is treated as non-empty.
func isEmptyContent(content any) bool {
	if content == nil {
		return true
	}
	if s, ok := content.(string); ok {
		return strings.TrimSpace(s) == ""
	}
	if parts, ok := content.([]any); ok {
		return len(parts) == 0
	}
	return false
}
// extractResponseTextAndToolCalls walks a Responses "output" array and
// returns the concatenated assistant text plus Chat Completions-shaped tool
// calls. When "output" is absent it falls back to a top-level "output_text"
// string. Tool calls appear either as top-level output items or nested
// inside a message item's content parts.
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
	output, ok := resp["output"].([]any)
	if !ok {
		if text, ok := resp["output_text"].(string); ok {
			return text, nil
		}
		return "", nil
	}
	textParts := make([]string, 0)
	toolCalls := make([]any, 0)
	for _, item := range output {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType := strings.TrimSpace(getString(itemMap["type"]))
		if itemType == "tool_call" || itemType == "function_call" {
			if tc := responseItemToChatToolCall(itemMap); tc != nil {
				toolCalls = append(toolCalls, tc)
			}
			continue
		}
		// Message-like item: harvest text and nested tool calls from content.
		content := itemMap["content"]
		switch v := content.(type) {
		case string:
			if v != "" {
				textParts = append(textParts, v)
			}
		case []any:
			for _, part := range v {
				partMap, ok := part.(map[string]any)
				if !ok {
					continue
				}
				partType := strings.TrimSpace(getString(partMap["type"]))
				switch partType {
				case "output_text", "text", "input_text":
					text := getString(partMap["text"])
					if text != "" {
						textParts = append(textParts, text)
					}
				case "tool_call", "function_call":
					if tc := responseItemToChatToolCall(partMap); tc != nil {
						toolCalls = append(toolCalls, tc)
					}
				}
			}
		}
	}
	return strings.Join(textParts, ""), toolCalls
}
// responseItemToChatToolCall converts one Responses tool/function-call item
// into a Chat Completions tool_calls entry ({"id","type":"function",
// "function":{name,arguments}}). The id prefers "call_id" over "id";
// name/arguments fall back to a nested "function" object. Returns nil when
// the item carries no data at all; a missing id is synthesized.
func responseItemToChatToolCall(item map[string]any) map[string]any {
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	arguments := getString(item["arguments"])
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if arguments == "" {
			arguments = getString(fn["arguments"])
		}
	}
	if name == "" && arguments == "" && callID == "" {
		return nil
	}
	if callID == "" {
		// Synthesize an id so the entry is addressable by clients.
		callID = "call_" + safeRandomHex(6)
	}
	return map[string]any{
		"id":   callID,
		"type": "function",
		"function": map[string]any{
			"name":      name,
			"arguments": arguments,
		},
	}
}
// extractResponseUsage converts a Responses "usage" object into Chat
// Completions usage (input_tokens/output_tokens → prompt/completion/total).
// Returns nil when usage is absent, malformed, or all-zero.
func extractResponseUsage(resp map[string]any) map[string]any {
	usage, ok := resp["usage"].(map[string]any)
	if !ok {
		return nil
	}
	promptTokens := int(getNumber(usage["input_tokens"]))
	completionTokens := int(getNumber(usage["output_tokens"]))
	if promptTokens == 0 && completionTokens == 0 {
		return nil
	}
	return map[string]any{
		"prompt_tokens":     promptTokens,
		"completion_tokens": completionTokens,
		"total_tokens":      promptTokens + completionTokens,
	}
}
// getString coerces a decoded-JSON value to a string: strings pass through,
// byte slices are converted, json.Number yields its literal text, and any
// other type (including nil) yields "".
func getString(value any) string {
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	case json.Number:
		return v.String()
	default:
		return ""
	}
}
// getNumber coerces a decoded-JSON value to float64, accepting the numeric
// types a decoder may produce; non-numeric values (and malformed
// json.Number) yield 0.
func getNumber(value any) float64 {
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int64:
		return float64(v)
	case json.Number:
		// Parse error deliberately ignored: treat as 0.
		f, _ := v.Float64()
		return f
	default:
		return 0
	}
}
// getInt64 coerces a decoded-JSON value to int64; float64 values are
// truncated toward zero, and non-numeric values (or a malformed
// json.Number) yield 0.
func getInt64(value any) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		return int64(v)
	case json.Number:
		// Parse error deliberately ignored: treat as 0.
		i, _ := v.Int64()
		return i
	default:
		return 0
	}
}
func
safeRandomHex
(
byteLength
int
)
string
{
value
,
err
:=
randomHexString
(
byteLength
)
if
err
!=
nil
||
value
==
""
{
return
"000000"
}
return
value
}
backend/internal/service/openai_chat_completions_forward.go
0 → 100644
View file @
8dd38f47
package
service
import
(
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
// chatStreamingResult carries the outcome of a streamed chat-completions
// forward: token usage (nil when upstream never reported it) and —
// presumably — the time-to-first-token in milliseconds (nil when no token
// was observed); confirm firstTokenMs semantics against the consumer.
type chatStreamingResult struct {
	usage        *OpenAIUsage
	firstTokenMs *int
}
// forwardChatCompletions proxies an OpenAI chat-completions request to the
// upstream for the given account. It applies the account's model mapping,
// forces stream_options.include_usage when the caller asked for usage on a
// streaming request, performs the upstream call (with proxy and concurrency
// limits), handles failover-worthy error statuses, and finally relays either
// the streaming or non-streaming response back through c.
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
	// Parse request body once (avoid multiple parse/serialize cycles)
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)
	originalModel := reqModel
	bodyModified := false
	// Rewrite the model name when the account defines a mapping; the original
	// name is restored in the response path so the client never sees it.
	mappedModel := account.GetMappedModel(reqModel)
	if mappedModel != reqModel {
		log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
		reqBody["model"] = mappedModel
		bodyModified = true
	}
	// Ensure the upstream emits a usage chunk for streaming requests, but only
	// when the client did not explicitly set include_usage itself.
	if reqStream && includeUsage {
		streamOptions, _ := reqBody["stream_options"].(map[string]any)
		if streamOptions == nil {
			streamOptions = map[string]any{}
		}
		if _, ok := streamOptions["include_usage"]; !ok {
			streamOptions["include_usage"] = true
			reqBody["stream_options"] = streamOptions
			bodyModified = true
		}
	}
	// Re-serialize only if something actually changed.
	if bodyModified {
		var err error
		body, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
	}
	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// Record the outbound body for ops inspection before sending.
	if c != nil {
		c.Set(OpsUpstreamRequestBodyKey, string(body))
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record a sanitized ops event and answer the
		// client with a generic 502 (no upstream detail is leaked).
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Read at most 2 MiB of the error body, then re-wrap it so later
			// consumers (side effects below) can still read it.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))
			upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
			upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
			upstreamDetail := ""
			// Optionally capture a truncated raw body for ops logs.
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			// Signal the caller to retry on another account.
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		// Non-failover error statuses are relayed to the client as-is.
		return s.handleErrorResponse(ctx, resp, c, account, body)
	}
	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
	}
	if usage == nil {
		usage = &OpenAIUsage{}
	}
	return &OpenAIForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel,
		Stream:       reqStream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
// buildChatCompletionsRequest constructs the upstream POST request for a
// chat-completions call: it resolves the target URL (account base URL if set,
// otherwise the default OpenAI endpoint), attaches the bearer token, copies
// only whitelisted client headers, and applies the account's custom
// user-agent when configured.
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
	var targetURL string
	baseURL := account.GetOpenAIBaseURL()
	if baseURL == "" {
		targetURL = openaiChatAPIURL
	} else {
		// Custom base URLs are validated (SSRF guard) before use.
		validatedURL, err := s.validateUpstreamBaseURL(baseURL)
		if err != nil {
			return nil, err
		}
		targetURL = validatedURL + "/chat/completions"
	}
	req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("authorization", "Bearer "+token)
	// Forward only headers in the chat-completions whitelist; everything else
	// from the client request is dropped.
	for key, values := range c.Request.Header {
		lowerKey := strings.ToLower(key)
		if openaiChatAllowedHeaders[lowerKey] {
			for _, v := range values {
				req.Header.Add(key, v)
			}
		}
	}
	// Account-level user-agent override wins over any forwarded client UA.
	customUA := account.GetOpenAIUserAgent()
	if customUA != "" {
		req.Header.Set("user-agent", customUA)
	}
	// Default the content type if the client did not send one.
	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}
	return req, nil
}
// handleChatCompletionsStreamingResponse relays the upstream SSE stream to the
// client line by line while: restoring the original model name on data lines,
// applying tool-call correction, measuring time-to-first-token, accumulating
// usage from usage chunks, enforcing a data-interval timeout, and emitting
// keepalive comments during idle periods. A background goroutine reads the
// upstream body so the main select loop can also service the tickers.
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	// Standard SSE response headers; X-Accel-Buffering disables nginx buffering.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}
	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}
	usage := &OpenAIUsage{}
	var firstTokenMs *int
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 64*1024), maxLineSize)
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers an event unless the consumer has already exited
	// (done closed), which prevents the reader goroutine from leaking.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	// Reader goroutine: pump upstream lines into the events channel and record
	// the time of the last successful read for the interval-timeout check.
	go func() {
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)
	// Optional watchdog: abort if no upstream data arrives within the interval.
	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}
	// Optional keepalive: emit SSE comments while the stream is idle so
	// intermediaries do not drop the connection.
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}
	var keepaliveTicker *time.Ticker
	if keepaliveInterval > 0 {
		keepaliveTicker = time.NewTicker(keepaliveInterval)
		defer keepaliveTicker.Stop()
	}
	var keepaliveCh <-chan time.Time
	if keepaliveTicker != nil {
		keepaliveCh = keepaliveTicker.C
	}
	lastDataAt := time.Now()
	errorEventSent := false
	// sendErrorEvent emits at most one terminal SSE error event to the client.
	sendErrorEvent := func(reason string) {
		if errorEventSent {
			return
		}
		errorEventSent = true
		_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\": \"%s\"}\n\n", reason)
		flusher.Flush()
	}
	needModelReplace := originalModel != mappedModel
	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream closed cleanly; return what was accumulated.
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
			}
			if ev.err != nil {
				if errors.Is(ev.err, bufio.ErrTooLong) {
					log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
					sendErrorEvent("response_too_large")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
				}
				sendErrorEvent("stream_read_error")
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
			}
			line := ev.line
			lastDataAt = time.Now()
			if openaiSSEDataRe.MatchString(line) {
				// data holds the payload with the "data:" prefix stripped.
				data := openaiSSEDataRe.ReplaceAllString(line, "")
				if needModelReplace {
					line = s.replaceModelInSSELine(line, mappedModel, originalModel)
				}
				// NOTE(review): when tool correction fires, the corrected line is
				// rebuilt from the pre-model-replacement `data`, so the mapped->
				// original model rewrite above appears to be lost for that line —
				// confirm whether CorrectToolCallsInSSEData preserves the model
				// field or this ordering needs fixing.
				if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
					line = "data: " + correctedData
				}
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()
				if firstTokenMs == nil {
					// Still waiting for the first meaningful delta: record TTFT
					// once a chunk carries content/tool calls, and also pick up
					// any usage payload.
					if event := parseChatStreamEvent(data); event != nil {
						if chatChunkHasDelta(event) {
							ms := int(time.Since(startTime).Milliseconds())
							firstTokenMs = &ms
						}
						applyChatUsageFromEvent(event, usage)
					}
				} else {
					if event := parseChatStreamEvent(data); event != nil {
						applyChatUsageFromEvent(event, usage)
					}
				}
			} else {
				// Non-data lines (event names, blank separators) pass through.
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()
			}
		case <-intervalCh:
			// Only time out if the reader goroutine has truly been idle for a
			// full interval (ticker fires are decoupled from reads).
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
			}
			sendErrorEvent("stream_timeout")
			return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
		case <-keepaliveCh:
			// Skip the keepalive if real data was forwarded recently.
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
			}
			flusher.Flush()
		}
	}
}
// handleChatCompletionsNonStreamingResponse buffers the upstream response,
// extracts usage, restores the original model name if a mapping was applied,
// runs tool-call correction, and writes the (possibly rewritten) body to the
// client with the upstream status code.
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	usage := &OpenAIUsage{}
	// Usage extraction is best-effort: an unparsable body still gets relayed.
	var parsed map[string]any
	if json.Unmarshal(body, &parsed) == nil {
		if usageMap, ok := parsed["usage"].(map[string]any); ok {
			applyChatUsageFromMap(usageMap, usage)
		}
	}
	// Hide the mapped model from the client by restoring the requested name.
	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}
	body = s.correctToolCallsInResponseBody(body)
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	contentType := "application/json"
	// When response-header filtering is disabled, prefer the upstream's own
	// content type; otherwise force application/json.
	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			contentType = upstreamType
		}
	}
	c.Data(resp.StatusCode, contentType, body)
	return usage, nil
}
// parseChatStreamEvent decodes one SSE data payload into a generic JSON map.
// The empty payload and the "[DONE]" stream terminator yield nil, as does any
// payload that is not valid JSON.
func parseChatStreamEvent(data string) map[string]any {
	switch data {
	case "", "[DONE]":
		return nil
	}
	var parsed map[string]any
	if err := json.Unmarshal([]byte(data), &parsed); err != nil {
		return nil
	}
	return parsed
}
// chatChunkHasDelta reports whether a streamed chat chunk carries an actual
// payload: non-whitespace text content, at least one tool call, or a
// non-empty legacy function_call. Chunks without choices (or with role-only
// deltas) report false.
func chatChunkHasDelta(event map[string]any) bool {
	choices, _ := event["choices"].([]any)
	for _, raw := range choices {
		choice, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if delta, ok := choice["delta"].(map[string]any); ok && deltaCarriesPayload(delta) {
			return true
		}
	}
	return false
}

// deltaCarriesPayload reports whether a single delta object contains visible
// content, tool calls, or a legacy function call.
func deltaCarriesPayload(delta map[string]any) bool {
	if text, ok := delta["content"].(string); ok && strings.TrimSpace(text) != "" {
		return true
	}
	if calls, ok := delta["tool_calls"].([]any); ok && len(calls) > 0 {
		return true
	}
	fc, ok := delta["function_call"].(map[string]any)
	return ok && len(fc) > 0
}
func
applyChatUsageFromEvent
(
event
map
[
string
]
any
,
usage
*
OpenAIUsage
)
{
if
event
==
nil
||
usage
==
nil
{
return
}
usageMap
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
applyChatUsageFromMap
(
usageMap
,
usage
)
}
func
applyChatUsageFromMap
(
usageMap
map
[
string
]
any
,
usage
*
OpenAIUsage
)
{
if
usageMap
==
nil
||
usage
==
nil
{
return
}
promptTokens
:=
int
(
getNumber
(
usageMap
[
"prompt_tokens"
]))
completionTokens
:=
int
(
getNumber
(
usageMap
[
"completion_tokens"
]))
if
promptTokens
>
0
{
usage
.
InputTokens
=
promptTokens
}
if
completionTokens
>
0
{
usage
.
OutputTokens
=
completionTokens
}
}
backend/internal/service/openai_chat_completions_test.go
0 → 100644
View file @
8dd38f47
package
service
import
(
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestConvertChatCompletionsToResponses verifies the chat-completions ->
// Responses API request conversion: messages (including assistant tool calls
// and tool results) become input items, legacy "functions" become tools, and
// "function_call" becomes tool_choice.
func TestConvertChatCompletionsToResponses(t *testing.T) {
	// Fixture covering the three message shapes the converter must handle:
	// plain user text, an assistant tool call, and the tool's result (with
	// extra fields that should be ignored).
	req := map[string]any{
		"model": "gpt-4o",
		"messages": []any{
			map[string]any{
				"role":    "user",
				"content": "hello",
			},
			map[string]any{
				"role": "assistant",
				"tool_calls": []any{
					map[string]any{
						"id":   "call_1",
						"type": "function",
						"function": map[string]any{
							"name":      "ping",
							"arguments": "{}",
						},
					},
				},
			},
			map[string]any{
				"role":          "tool",
				"tool_call_id":  "call_1",
				"content":       "ok",
				"response":      "ignored",
				"response_time": 1,
			},
		},
		"functions": []any{
			map[string]any{
				"name":        "ping",
				"description": "ping tool",
				"parameters":  map[string]any{"type": "object"},
			},
		},
		"function_call": map[string]any{"name": "ping"},
	}
	converted, err := ConvertChatCompletionsToResponses(req)
	require.NoError(t, err)
	require.Equal(t, "gpt-4o", converted["model"])
	// Each of the three messages must map to exactly one input item.
	input, ok := converted["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 3)
	// Tool call and tool output must be linked by the same call_id.
	toolCall := findInputItemByType(input, "tool_call")
	require.NotNil(t, toolCall)
	require.Equal(t, "call_1", toolCall["call_id"])
	toolOutput := findInputItemByType(input, "function_call_output")
	require.NotNil(t, toolOutput)
	require.Equal(t, "call_1", toolOutput["call_id"])
	// Legacy "functions" are carried over as tools.
	tools, ok := converted["tools"].([]any)
	require.True(t, ok)
	require.Len(t, tools, 1)
	// Legacy "function_call" is carried over as tool_choice.
	require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
}
// TestConvertResponsesToChatCompletion verifies the Responses API ->
// chat-completion conversion: message output text becomes the assistant
// message content and input/output token counts are renamed to
// prompt/completion tokens with a derived total.
func TestConvertResponsesToChatCompletion(t *testing.T) {
	// Minimal Responses API payload: one assistant message with one text part.
	resp := map[string]any{
		"id":         "resp_123",
		"model":      "gpt-4o",
		"created_at": 1700000000,
		"output": []any{
			map[string]any{
				"type": "message",
				"role": "assistant",
				"content": []any{
					map[string]any{
						"type": "output_text",
						"text": "hi",
					},
				},
			},
		},
		"usage": map[string]any{
			"input_tokens":  2,
			"output_tokens": 3,
		},
	}
	body, err := json.Marshal(resp)
	require.NoError(t, err)
	converted, err := ConvertResponsesToChatCompletion(body)
	require.NoError(t, err)
	var chat map[string]any
	require.NoError(t, json.Unmarshal(converted, &chat))
	require.Equal(t, "chat.completion", chat["object"])
	// Exactly one choice whose message carries the output text.
	choices, ok := chat["choices"].([]any)
	require.True(t, ok)
	require.Len(t, choices, 1)
	choice, ok := choices[0].(map[string]any)
	require.True(t, ok)
	message, ok := choice["message"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "hi", message["content"])
	// Usage fields renamed; total_tokens = prompt + completion. Values are
	// float64 because they round-trip through encoding/json.
	usage, ok := chat["usage"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(2), usage["prompt_tokens"])
	require.Equal(t, float64(3), usage["completion_tokens"])
	require.Equal(t, float64(5), usage["total_tokens"])
}
// findInputItemByType returns the first element of items that is a map whose
// "type" field equals itemType, or nil when none matches.
func findInputItemByType(items []any, itemType string) map[string]any {
	for _, candidate := range items {
		if m, ok := candidate.(map[string]any); ok && m["type"] == itemType {
			return m
		}
	}
	return nil
}
backend/internal/service/openai_gateway_service.go
View file @
8dd38f47
...
@@ -12,6 +12,7 @@ import (
...
@@ -12,6 +12,7 @@ import (
"io"
"io"
"math/rand"
"math/rand"
"net/http"
"net/http"
"regexp"
"sort"
"sort"
"strconv"
"strconv"
"strings"
"strings"
...
@@ -36,6 +37,7 @@ const (
...
@@ -36,6 +37,7 @@ const (
chatgptCodexURL
=
"https://chatgpt.com/backend-api/codex/responses"
chatgptCodexURL
=
"https://chatgpt.com/backend-api/codex/responses"
// OpenAI Platform API for API Key accounts (fallback)
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL
=
"https://api.openai.com/v1/responses"
openaiPlatformAPIURL
=
"https://api.openai.com/v1/responses"
openaiChatAPIURL
=
"https://api.openai.com/v1/chat/completions"
openaiStickySessionTTL
=
time
.
Hour
// 粘性会话TTL
openaiStickySessionTTL
=
time
.
Hour
// 粘性会话TTL
codexCLIUserAgent
=
"codex_cli_rs/0.104.0"
codexCLIUserAgent
=
"codex_cli_rs/0.104.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
...
@@ -54,6 +56,16 @@ const (
...
@@ -54,6 +56,16 @@ const (
codexCLIVersion
=
"0.104.0"
codexCLIVersion
=
"0.104.0"
)
)
// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
const
OpenAIChatCompletionsBodyKey
=
"openai_chat_completions_body"
// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
const
OpenAIChatCompletionsIncludeUsageKey
=
"openai_chat_completions_include_usage"
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
openaiSSEDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
// OpenAI allowed headers whitelist (for non-passthrough).
// OpenAI allowed headers whitelist (for non-passthrough).
var
openaiAllowedHeaders
=
map
[
string
]
bool
{
var
openaiAllowedHeaders
=
map
[
string
]
bool
{
"accept-language"
:
true
,
"accept-language"
:
true
,
...
@@ -97,6 +109,19 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
...
@@ -97,6 +109,19 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
"X-Real-IP"
,
"X-Real-IP"
,
}
}
// OpenAI chat-completions allowed headers (extend responses whitelist).
var
openaiChatAllowedHeaders
=
map
[
string
]
bool
{
"accept-language"
:
true
,
"content-type"
:
true
,
"conversation_id"
:
true
,
"user-agent"
:
true
,
"originator"
:
true
,
"session_id"
:
true
,
"openai-organization"
:
true
,
"openai-project"
:
true
,
"openai-beta"
:
true
,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type
OpenAICodexUsageSnapshot
struct
{
type
OpenAICodexUsageSnapshot
struct
{
PrimaryUsedPercent
*
float64
`json:"primary_used_percent,omitempty"`
PrimaryUsedPercent
*
float64
`json:"primary_used_percent,omitempty"`
...
@@ -1577,6 +1602,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
...
@@ -1577,6 +1602,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return
nil
,
errors
.
New
(
"codex_cli_only restriction: only codex official clients are allowed"
)
return
nil
,
errors
.
New
(
"codex_cli_only restriction: only codex official clients are allowed"
)
}
}
if
c
!=
nil
&&
account
!=
nil
&&
account
.
Type
==
AccountTypeAPIKey
{
if
raw
,
ok
:=
c
.
Get
(
OpenAIChatCompletionsBodyKey
);
ok
{
if
rawBody
,
ok
:=
raw
.
([]
byte
);
ok
&&
len
(
rawBody
)
>
0
{
includeUsage
:=
false
if
v
,
ok
:=
c
.
Get
(
OpenAIChatCompletionsIncludeUsageKey
);
ok
{
if
flag
,
ok
:=
v
.
(
bool
);
ok
{
includeUsage
=
flag
}
}
if
passthroughWriter
,
ok
:=
c
.
Writer
.
(
interface
{
SetPassthrough
()
});
ok
{
passthroughWriter
.
SetPassthrough
()
}
return
s
.
forwardChatCompletions
(
ctx
,
c
,
account
,
rawBody
,
includeUsage
,
startTime
)
}
}
}
originalBody
:=
body
originalBody
:=
body
reqModel
,
reqStream
,
promptCacheKey
:=
extractOpenAIRequestMetaFromBody
(
body
)
reqModel
,
reqStream
,
promptCacheKey
:=
extractOpenAIRequestMetaFromBody
(
body
)
originalModel
:=
reqModel
originalModel
:=
reqModel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment