Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a09478f3
Unverified
Commit
a09478f3
authored
Feb 03, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #316 from cyhhao/fix/claude-oauth-compat
fix(网关): 完善 Claude OAuth/Claude Code 兼容
parents
0ab68aa9
2fe8932c
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
a09478f3
...
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
return
}
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
...
...
backend/internal/pkg/claude/constants.go
View file @
a09478f3
...
@@ -9,11 +9,26 @@ const (
...
@@ -9,11 +9,26 @@ const (
BetaClaudeCode
=
"claude-code-20250219"
BetaClaudeCode
=
"claude-code-20250219"
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting
=
"token-counting-2024-11-01"
)
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const
DefaultBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
const
DefaultBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const
MessageBetaHeaderNoTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const
MessageBetaHeaderWithTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const
CountTokensBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
...
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
...
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var
DefaultHeaders
=
map
[
string
]
string
{
var
DefaultHeaders
=
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent"
:
"claude-cli/2.1.22 (external, cli)"
,
"X-Stainless-Lang"
:
"js"
,
"X-Stainless-Lang"
:
"js"
,
"X-Stainless-Package-Version"
:
"0.
52
.0"
,
"X-Stainless-Package-Version"
:
"0.
70
.0"
,
"X-Stainless-OS"
:
"Linux"
,
"X-Stainless-OS"
:
"Linux"
,
"X-Stainless-Arch"
:
"
x
64"
,
"X-Stainless-Arch"
:
"
arm
64"
,
"X-Stainless-Runtime"
:
"node"
,
"X-Stainless-Runtime"
:
"node"
,
"X-Stainless-Runtime-Version"
:
"v2
2
.1
4
.0"
,
"X-Stainless-Runtime-Version"
:
"v2
4
.1
3
.0"
,
"X-Stainless-Retry-Count"
:
"0"
,
"X-Stainless-Retry-Count"
:
"0"
,
"X-Stainless-Timeout"
:
"60"
,
"X-Stainless-Timeout"
:
"60
0
"
,
"X-App"
:
"cli"
,
"X-App"
:
"cli"
,
"Anthropic-Dangerous-Direct-Browser-Access"
:
"true"
,
"Anthropic-Dangerous-Direct-Browser-Access"
:
"true"
,
}
}
...
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
...
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型
// DefaultTestModel 测试时使用的默认模型
const
DefaultTestModel
=
"claude-sonnet-4-5-20250929"
const
DefaultTestModel
=
"claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var
ModelIDOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5"
:
"claude-sonnet-4-5-20250929"
,
"claude-opus-4-5"
:
"claude-opus-4-5-20251101"
,
"claude-haiku-4-5"
:
"claude-haiku-4-5-20251001"
,
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var
ModelIDReverseOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5"
,
"claude-haiku-4-5-20251001"
:
"claude-haiku-4-5"
,
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func
NormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func
DenormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDReverseOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
backend/internal/service/account.go
View file @
a09478f3
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return
""
return
""
}
}
func
(
a
*
Account
)
GetClaudeUserID
()
string
{
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
return
""
}
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
return
false
return
false
...
...
backend/internal/service/account_test_service.go
View file @
a09478f3
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
"system"
:
[]
map
[
string
]
any
{
{
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
"type"
:
"ephemeral"
,
},
},
...
...
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
a09478f3
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
tests
:=
[]
struct
{
name
string
name
string
body
string
body
string
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
},
{
{
name
:
"string system equals Claude Code prompt"
,
name
:
"string system equals Claude Code prompt"
,
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
},
{
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
},
{
{
name
:
"empty array"
,
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
a09478f3
This diff is collapsed.
Click to expand it.
backend/internal/service/identity_service.go
View file @
a09478f3
...
@@ -26,13 +26,13 @@ var (
...
@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.
0.6
2 (external, cli)"
,
UserAgent
:
"claude-cli/2.
1.2
2 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.
52
.0"
,
StainlessPackageVersion
:
"0.
70
.0"
,
StainlessOS
:
"Linux"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v2
2
.1
4
.0"
,
StainlessRuntimeVersion
:
"v2
4
.1
3
.0"
,
}
}
// Fingerprint represents account fingerprint data
// Fingerprint represents account fingerprint data
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
}
// parseUserAgentVersion 解析user-agent版本号
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.
0.6
2 -> (2,
0
,
6
2)
// 例如:claude-cli/2.
1.
2 -> (2,
1
, 2)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
a09478f3
...
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent
:=
false
errorEventSent
:=
false
clientDisconnected
:=
false
// 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
clientDisconnected
{
return
return
}
}
errorEventSent
=
true
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
payload
:=
map
[
string
]
any
{
flusher
.
Flush
()
"type"
:
"error"
,
"sequence_number"
:
0
,
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
reason
,
"code"
:
reason
,
},
}
if
b
,
err
:=
json
.
Marshal
(
payload
);
err
==
nil
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"data: %s
\n\n
"
,
b
)
flusher
.
Flush
()
}
}
}
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
...
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
}
if
ev
.
err
!=
nil
{
if
ev
.
err
!=
nil
{
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
log
.
Printf
(
"Context canceled during streaming, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if
clientDisconnected
{
log
.
Printf
(
"Upstream read error after client disconnect: %v, returning collected usage"
,
ev
.
err
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
sendErrorEvent
(
"response_too_large"
)
...
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
data
=
correctedData
line
=
"data: "
+
correctedData
line
=
"data: "
+
correctedData
}
}
// Forward line
// 写入客户端(客户端断开后继续 drain 上游)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
!
clientDisconnected
{
sendErrorEvent
(
"write_failed"
)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
}
flusher
.
Flush
()
// Record first token time
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
...
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s
.
parseSSEUsage
(
data
,
usage
)
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
}
else
{
// Forward non-data lines as-is
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
!
clientDisconnected
{
sendErrorEvent
(
"write_failed"
)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
}
flusher
.
Flush
()
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
...
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
continue
}
}
if
clientDisconnected
{
log
.
Printf
(
"Upstream timeout after client disconnect, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
if
s
.
rateLimitService
!=
nil
{
...
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
case
<-
keepaliveCh
:
case
<-
keepaliveCh
:
if
clientDisconnected
{
continue
}
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
continue
}
}
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
continue
}
}
flusher
.
Flush
()
flusher
.
Flush
()
}
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
a09478f3
...
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
...
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad
bool
skipDefaultLoad
bool
}
}
type
cancelReadCloser
struct
{}
func
(
c
cancelReadCloser
)
Read
(
p
[]
byte
)
(
int
,
error
)
{
return
0
,
context
.
Canceled
}
func
(
c
cancelReadCloser
)
Close
()
error
{
return
nil
}
type
failingGinWriter
struct
{
gin
.
ResponseWriter
failAfter
int
writes
int
}
func
(
w
*
failingGinWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
c
.
acquireResults
!=
nil
{
if
c
.
acquireResults
!=
nil
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
...
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
...
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected stream_timeout SSE error, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected OpenAI-compatible error SSE event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{},
}
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_read_error"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
failingGinWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:3,
\"
output_tokens
\"
:5,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:1}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
result
==
nil
||
result
.
usage
==
nil
{
t
.
Fatalf
(
"expected usage result"
)
}
if
result
.
usage
.
InputTokens
!=
3
||
result
.
usage
.
OutputTokens
!=
5
||
result
.
usage
.
CacheReadInputTokens
!=
1
{
t
.
Fatalf
(
"unexpected usage: %+v"
,
*
result
.
usage
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"write_failed"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
...
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
t
.
Fatalf
(
"expected
response_too_large
SSE e
rror
, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected
OpenAI-compatible error
SSE e
vent
, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment