Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a09478f3
Unverified
Commit
a09478f3
authored
Feb 03, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #316 from cyhhao/fix/claude-oauth-compat
fix(网关): 完善 Claude OAuth/Claude Code 兼容
parents
0ab68aa9
2fe8932c
Changes
12
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
a09478f3
...
...
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
...
...
backend/internal/pkg/claude/constants.go
View file @
a09478f3
...
...
@@ -9,11 +9,26 @@ const (
BetaClaudeCode
=
"claude-code-20250219"
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting
=
"token-counting-2024-11-01"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const
DefaultBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const
MessageBetaHeaderNoTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const
MessageBetaHeaderWithTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const
CountTokensBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
...
...
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var
DefaultHeaders
=
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent"
:
"claude-cli/2.1.22 (external, cli)"
,
"X-Stainless-Lang"
:
"js"
,
"X-Stainless-Package-Version"
:
"0.
52
.0"
,
"X-Stainless-Package-Version"
:
"0.
70
.0"
,
"X-Stainless-OS"
:
"Linux"
,
"X-Stainless-Arch"
:
"
x
64"
,
"X-Stainless-Arch"
:
"
arm
64"
,
"X-Stainless-Runtime"
:
"node"
,
"X-Stainless-Runtime-Version"
:
"v2
2
.1
4
.0"
,
"X-Stainless-Runtime-Version"
:
"v2
4
.1
3
.0"
,
"X-Stainless-Retry-Count"
:
"0"
,
"X-Stainless-Timeout"
:
"60"
,
"X-Stainless-Timeout"
:
"60
0
"
,
"X-App"
:
"cli"
,
"Anthropic-Dangerous-Direct-Browser-Access"
:
"true"
,
}
...
...
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型
const
DefaultTestModel
=
"claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var
ModelIDOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5"
:
"claude-sonnet-4-5-20250929"
,
"claude-opus-4-5"
:
"claude-opus-4-5-20251101"
,
"claude-haiku-4-5"
:
"claude-haiku-4-5-20251001"
,
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var
ModelIDReverseOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5"
,
"claude-haiku-4-5-20251001"
:
"claude-haiku-4-5"
,
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func
NormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func
DenormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDReverseOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
backend/internal/service/account.go
View file @
a09478f3
...
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return
""
}
func
(
a
*
Account
)
GetClaudeUserID
()
string
{
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
return
""
}
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
return
false
...
...
backend/internal/service/account_test_service.go
View file @
a09478f3
...
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
{
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
},
...
...
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
a09478f3
...
...
@@ -2,6 +2,7 @@ package service
import
(
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
...
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
name
string
body
string
...
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
{
name
:
"string system equals Claude Code prompt"
,
...
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
{
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
a09478f3
...
...
@@ -20,12 +20,14 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
...
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
return
""
...
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
return
sessionHash
[
:
8
]
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
toolPrefixRe
=
regexp
.
MustCompile
(
`(?i)^(?:oc_|mcp_)`
)
toolNameBoundaryRe
=
regexp
.
MustCompile
(
`[^a-zA-Z0-9]+`
)
toolNameCamelRe
=
regexp
.
MustCompile
(
`([a-z0-9])([A-Z])`
)
toolNameFieldRe
=
regexp
.
MustCompile
(
`"name"\s*:\s*"([^"]+)"`
)
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
"read"
:
"Read"
,
"edit"
:
"Edit"
,
"write"
:
"Write"
,
"task"
:
"Task"
,
"glob"
:
"Glob"
,
"grep"
:
"Grep"
,
"webfetch"
:
"WebFetch"
,
"websearch"
:
"WebSearch"
,
"todowrite"
:
"TodoWrite"
,
"question"
:
"AskUserQuestion"
,
}
openCodeToolOverrides
=
map
[
string
]
string
{
"Bash"
:
"bash"
,
"Read"
:
"read"
,
"Edit"
:
"edit"
,
"Write"
:
"write"
,
"Task"
:
"task"
,
"Glob"
:
"glob"
,
"Grep"
:
"grep"
,
"WebFetch"
:
"webfetch"
,
"WebSearch"
:
"websearch"
,
"TodoWrite"
:
"todowrite"
,
"AskUserQuestion"
:
"question"
,
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...
...
@@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return
newBody
}
type
claudeOAuthNormalizeOptions
struct
{
injectMetadata
bool
metadataUserID
string
stripSystemCacheControl
bool
}
func
stripToolPrefix
(
value
string
)
string
{
if
value
==
""
{
return
value
}
return
toolPrefixRe
.
ReplaceAllString
(
value
,
""
)
}
func
toPascalCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
normalized
:=
toolNameBoundaryRe
.
ReplaceAllString
(
value
,
" "
)
tokens
:=
make
([]
string
,
0
)
for
_
,
token
:=
range
strings
.
Fields
(
normalized
)
{
expanded
:=
toolNameCamelRe
.
ReplaceAllString
(
token
,
"$1 $2"
)
parts
:=
strings
.
Fields
(
expanded
)
if
len
(
parts
)
>
0
{
tokens
=
append
(
tokens
,
parts
...
)
}
}
if
len
(
tokens
)
==
0
{
return
value
}
var
builder
strings
.
Builder
for
_
,
token
:=
range
tokens
{
lower
:=
strings
.
ToLower
(
token
)
if
lower
==
""
{
continue
}
runes
:=
[]
rune
(
lower
)
runes
[
0
]
=
unicode
.
ToUpper
(
runes
[
0
])
_
,
_
=
builder
.
WriteString
(
string
(
runes
))
}
return
builder
.
String
()
}
func
toSnakeCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
output
:=
toolNameCamelRe
.
ReplaceAllString
(
value
,
"$1_$2"
)
output
=
toolNameBoundaryRe
.
ReplaceAllString
(
output
,
"_"
)
output
=
strings
.
Trim
(
output
,
"_"
)
return
strings
.
ToLower
(
output
)
}
func
normalizeToolNameForClaude
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
mapped
,
ok
:=
claudeToolNameOverrides
[
strings
.
ToLower
(
stripped
)]
if
!
ok
{
mapped
=
toPascalCase
(
stripped
)
}
if
mapped
!=
""
&&
cache
!=
nil
&&
mapped
!=
stripped
{
cache
[
mapped
]
=
stripped
}
if
mapped
==
""
{
return
stripped
}
return
mapped
}
func
normalizeToolNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
stripped
];
ok
{
return
mapped
}
}
if
mapped
,
ok
:=
openCodeToolOverrides
[
stripped
];
ok
{
return
mapped
}
return
toSnakeCase
(
stripped
)
}
func
normalizeParamNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
name
];
ok
{
return
mapped
}
}
return
name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
func
sanitizeToolDescription
(
description
string
)
string
{
if
description
==
""
{
return
description
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
schema
,
ok
:=
inputSchema
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
properties
,
ok
:=
schema
[
"properties"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
newProperties
:=
make
(
map
[
string
]
any
,
len
(
properties
))
for
key
,
value
:=
range
properties
{
snakeKey
:=
toSnakeCase
(
key
)
newProperties
[
snakeKey
]
=
value
if
snakeKey
!=
key
&&
cache
!=
nil
{
cache
[
snakeKey
]
=
key
}
}
schema
[
"properties"
]
=
newProperties
if
required
,
ok
:=
schema
[
"required"
]
.
([]
any
);
ok
{
newRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
item
:=
range
required
{
name
,
ok
:=
item
.
(
string
)
if
!
ok
{
newRequired
=
append
(
newRequired
,
item
)
continue
}
snakeName
:=
toSnakeCase
(
name
)
newRequired
=
append
(
newRequired
,
snakeName
)
if
snakeName
!=
name
&&
cache
!=
nil
{
cache
[
snakeName
]
=
name
}
}
schema
[
"required"
]
=
newRequired
}
}
func
stripCacheControlFromSystemBlocks
(
system
any
)
bool
{
blocks
,
ok
:=
system
.
([]
any
)
if
!
ok
{
return
false
}
changed
:=
false
for
_
,
item
:=
range
blocks
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
_
,
exists
:=
block
[
"cache_control"
];
!
exists
{
continue
}
delete
(
block
,
"cache_control"
)
changed
=
true
}
return
changed
}
func
normalizeClaudeOAuthRequestBody
(
body
[]
byte
,
modelID
string
,
opts
claudeOAuthNormalizeOptions
)
([]
byte
,
string
,
map
[
string
]
string
)
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
case
[]
any
:
for
_
,
item
:=
range
v
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
block
[
"type"
]
.
(
string
);
blockType
!=
"text"
{
continue
}
text
,
ok
:=
block
[
"text"
]
.
(
string
)
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
}
}
}
if
rawModel
,
ok
:=
req
[
"model"
]
.
(
string
);
ok
{
normalized
:=
claude
.
NormalizeModelID
(
rawModel
)
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
}
}
if
rawTools
,
exists
:=
req
[
"tools"
];
exists
{
switch
tools
:=
rawTools
.
(
type
)
{
case
[]
any
:
for
idx
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
tools
[
idx
]
=
toolMap
}
req
[
"tools"
]
=
tools
case
map
[
string
]
any
:
normalizedTools
:=
make
(
map
[
string
]
any
,
len
(
tools
))
for
name
,
value
:=
range
tools
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
==
""
{
normalized
=
name
}
if
toolMap
,
ok
:=
value
.
(
map
[
string
]
any
);
ok
{
toolMap
[
"name"
]
=
normalized
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
normalizedTools
[
normalized
]
=
toolMap
continue
}
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
continue
}
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
);
blockType
!=
"tool_use"
{
continue
}
if
name
,
ok
:=
blockMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
}
}
}
}
}
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
}
}
if
opts
.
injectMetadata
&&
opts
.
metadataUserID
!=
""
{
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
metadata
=
map
[
string
]
any
{}
req
[
"metadata"
]
=
metadata
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
return
newBody
,
modelID
,
toolNameMap
}
func
(
s
*
GatewayService
)
buildOAuthMetadataUserID
(
parsed
*
ParsedRequest
,
account
*
Account
,
fp
*
Fingerprint
)
string
{
if
parsed
==
nil
||
account
==
nil
{
return
""
}
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
sessionID
:=
uuid
.
NewString
()
if
sessionHash
!=
""
{
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
if
seed
==
""
{
return
uuid
.
NewString
()
}
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
...
...
@@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
}
func
isClaudeCodeRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
parsed
*
ParsedRequest
)
bool
{
if
IsClaudeCodeClient
(
ctx
)
{
return
true
}
if
parsed
==
nil
||
c
==
nil
{
return
false
}
return
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
...
...
@@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
...
...
@@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
}
...
...
@@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
originalModel
:=
reqModel
var
toolNameMap
map
[
string
]
string
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
if
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
&&
fp
!=
nil
{
if
metadataUserID
:=
s
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
);
metadataUserID
!=
""
{
normalizeOpts
.
injectMetadata
=
true
normalizeOpts
.
metadataUserID
=
metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
account
.
IsOAuth
()
&&
!
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
&&
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
body
,
reqModel
,
toolNameMap
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
originalModel
:=
reqModel
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
...
...
@@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
...
...
@@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
reqStream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
)
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
.
Error
()
==
"have error in stream"
{
return
nil
,
&
UpstreamFailoverError
{
...
...
@@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
)
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
},
nil
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
...
...
@@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
...
...
@@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
reqStream
)
}
// 处理anthropic-beta header(OAuth账号需要
特殊处理
)
// 处理
anthropic-beta header(OAuth
账号需要
包含 oauth beta
)
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
))
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
...
...
@@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
@@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return
claude
.
APIKeyBetaHeader
}
func
applyClaudeOAuthHeaderDefaults
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
if
req
.
Header
.
Get
(
key
)
==
""
{
req
.
Header
.
Set
(
key
,
value
)
}
}
if
isStream
&&
req
.
Header
.
Get
(
"x-stainless-helper-method"
)
==
""
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
add
:=
func
(
v
string
)
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
...
...
@@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
...
@@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
...
@@ -3130,7 +3893,7 @@ type streamingResult struct {
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
...
@@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines
:=
make
([]
string
,
0
,
4
)
var
toolInputBuffers
map
[
int
]
string
if
mimicClaudeCode
{
toolInputBuffers
=
make
(
map
[
int
]
string
)
}
transformToolInputJSON
:=
func
(
raw
string
)
string
{
if
!
mimicClaudeCode
{
return
raw
}
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
var
parsed
any
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
parsed
);
err
!=
nil
{
return
replaceToolNamesInText
(
raw
,
toolNameMap
)
}
rewritten
,
changed
:=
rewriteParamKeysInValue
(
parsed
,
toolNameMap
)
if
changed
{
if
bytes
,
err
:=
json
.
Marshal
(
rewritten
);
err
==
nil
{
return
string
(
bytes
)
}
}
return
raw
}
processSSEEvent
:=
func
(
lines
[]
string
)
([]
string
,
string
,
error
)
{
if
len
(
lines
)
==
0
{
return
nil
,
""
,
nil
}
eventName
:=
""
dataLine
:=
""
for
_
,
line
:=
range
lines
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
{
eventName
=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
))
continue
}
if
dataLine
==
""
&&
sseDataRe
.
MatchString
(
trimmed
)
{
dataLine
=
sseDataRe
.
ReplaceAllString
(
trimmed
,
""
)
}
}
if
eventName
==
"error"
{
return
nil
,
dataLine
,
errors
.
New
(
"have error in stream"
)
}
if
dataLine
==
""
{
return
[]
string
{
strings
.
Join
(
lines
,
"
\n
"
)
+
"
\n\n
"
},
""
,
nil
}
if
dataLine
==
"[DONE]"
{
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
dataLine
+
"
\n\n
"
return
[]
string
{
block
},
dataLine
,
nil
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
dataLine
),
&
event
);
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
eventType
,
_
:=
event
[
"type"
]
.
(
string
)
if
eventName
==
""
{
eventName
=
eventType
}
if
needModelReplace
{
if
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
model
,
ok
:=
msg
[
"model"
]
.
(
string
);
ok
&&
model
==
mappedModel
{
msg
[
"model"
]
=
originalModel
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_delta"
{
if
delta
,
ok
:=
event
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
deltaType
,
_
:=
delta
[
"type"
]
.
(
string
);
deltaType
==
"input_json_delta"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
partial
,
ok
:=
delta
[
"partial_json"
]
.
(
string
);
ok
{
toolInputBuffers
[
index
]
+=
partial
}
}
return
nil
,
dataLine
,
nil
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_stop"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
buffered
:=
toolInputBuffers
[
index
];
buffered
!=
""
{
delete
(
toolInputBuffers
,
index
)
transformed
:=
transformToolInputJSON
(
buffered
)
synthetic
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
any
{
"type"
:
"input_json_delta"
,
"partial_json"
:
transformed
,
},
}
synthBytes
,
synthErr
:=
json
.
Marshal
(
synthetic
)
if
synthErr
==
nil
{
synthBlock
:=
"event: content_block_delta
\n
"
+
"data: "
+
string
(
synthBytes
)
+
"
\n\n
"
rewriteToolNamesInValue
(
event
,
toolNameMap
)
stopBytes
,
stopErr
:=
json
.
Marshal
(
event
)
if
stopErr
==
nil
{
stopBlock
:=
""
if
eventName
!=
""
{
stopBlock
=
"event: "
+
eventName
+
"
\n
"
}
stopBlock
+=
"data: "
+
string
(
stopBytes
)
+
"
\n\n
"
return
[]
string
{
synthBlock
,
stopBlock
},
string
(
stopBytes
),
nil
}
}
}
}
}
if
mimicClaudeCode
{
rewriteToolNamesInValue
(
event
,
toolNameMap
)
}
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
string
(
newData
)
+
"
\n\n
"
return
[]
string
{
block
},
string
(
newData
),
nil
}
for
{
select
{
case
ev
,
ok
:=
<-
events
:
...
...
@@ -3253,43 +4181,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
if
line
==
"event: error"
{
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
errors
.
New
(
"have error in stream"
)
}
trimmed
:=
strings
.
TrimSpace
(
line
)
// Extract data from SSE line (supports both "data: " and "data:" formats)
var
data
string
if
sseDataRe
.
MatchString
(
line
)
{
data
=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
if
trimmed
==
""
{
if
len
(
pendingEventLines
)
==
0
{
continue
}
}
// 写入客户端(统一处理 data 行和非 data 行)
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
outputBlocks
,
data
,
err
:=
processSSEEvent
(
pendingEventLines
)
pendingEventLines
=
pendingEventLines
[
:
0
]
if
err
!=
nil
{
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
err
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
for
_
,
block
:=
range
outputBlocks
{
if
!
clientDisconnected
{
if
_
,
werr
:=
fmt
.
Fprint
(
w
,
block
);
werr
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
break
}
flusher
.
Flush
()
}
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
s
.
parseSSEUsage
(
data
,
usage
)
continue
}
pendingEventLines
=
append
(
pendingEventLines
,
line
)
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
...
@@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
if
!
sseDataRe
.
MatchString
(
line
)
{
return
line
}
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
""
||
data
==
"[DONE]"
{
return
line
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
return
line
}
// 只替换 message_start 事件中的 message.model
if
event
[
"type"
]
!=
"message_start"
{
return
line
func
rewriteParamKeysInValue
(
value
any
,
cache
map
[
string
]
string
)
(
any
,
bool
)
{
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
changed
:=
false
rewritten
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
item
:=
range
v
{
newKey
:=
normalizeParamNameForOpenCode
(
key
,
cache
)
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
if
newKey
!=
key
{
changed
=
true
}
rewritten
[
newKey
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
case
[]
any
:
changed
:=
false
rewritten
:=
make
([]
any
,
len
(
v
))
for
idx
,
item
:=
range
v
{
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
rewritten
[
idx
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
default
:
return
value
,
false
}
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
line
func
rewriteToolNamesInValue
(
value
any
,
toolNameMap
map
[
string
]
string
)
bool
{
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
changed
:=
false
if
blockType
,
_
:=
v
[
"type"
]
.
(
string
);
blockType
==
"tool_use"
{
if
name
,
ok
:=
v
[
"name"
]
.
(
string
);
ok
{
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
!=
name
{
v
[
"name"
]
=
mapped
changed
=
true
}
}
if
input
,
ok
:=
v
[
"input"
]
.
(
map
[
string
]
any
);
ok
{
rewrittenInput
,
inputChanged
:=
rewriteParamKeysInValue
(
input
,
toolNameMap
)
if
inputChanged
{
if
m
,
ok
:=
rewrittenInput
.
(
map
[
string
]
any
);
ok
{
v
[
"input"
]
=
m
changed
=
true
}
}
}
}
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
case
[]
any
:
changed
:=
false
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
default
:
return
false
}
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
line
func
replaceToolNamesInText
(
text
string
,
toolNameMap
map
[
string
]
string
)
string
{
if
text
==
""
{
return
text
}
output
:=
toolNameFieldRe
.
ReplaceAllStringFunc
(
text
,
func
(
match
string
)
string
{
submatches
:=
toolNameFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
name
:=
submatches
[
1
]
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
==
name
{
return
match
}
return
strings
.
Replace
(
match
,
name
,
mapped
,
1
)
})
output
=
modelFieldRe
.
ReplaceAllStringFunc
(
output
,
func
(
match
string
)
string
{
submatches
:=
modelFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
model
:=
submatches
[
1
]
mapped
:=
claude
.
DenormalizeModelID
(
model
)
if
mapped
==
model
{
return
match
}
return
strings
.
Replace
(
match
,
model
,
mapped
,
1
)
})
msg
[
"model"
]
=
toModel
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
return
line
for
mapped
,
original
:=
range
toolNameMap
{
if
mapped
==
""
||
original
==
""
||
mapped
==
original
{
continue
}
output
=
strings
.
ReplaceAll
(
output
,
"
\"
"
+
mapped
+
"
\"
:"
,
"
\"
"
+
original
+
"
\"
:"
)
output
=
strings
.
ReplaceAll
(
output
,
"
\\\"
"
+
mapped
+
"
\\\"
:"
,
"
\\\"
"
+
original
+
"
\\\"
:"
)
}
return
"data: "
+
string
(
newData
)
return
output
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
...
...
@@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
...
@@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
if
mimicClaudeCode
{
body
=
s
.
replaceToolNamesInResponseBody
(
body
,
toolNameMap
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
...
@@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
newBody
}
func
(
s
*
GatewayService
)
replaceToolNamesInResponseBody
(
body
[]
byte
,
toolNameMap
map
[
string
]
string
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
replaced
:=
replaceToolNamesInText
(
string
(
body
),
toolNameMap
)
if
replaced
==
string
(
body
)
{
return
body
}
return
[]
byte
(
replaced
)
}
if
!
rewriteToolNamesInValue
(
resp
,
toolNameMap
)
{
return
body
}
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
...
...
@@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
body
,
reqModel
,
_
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
...
...
@@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
return
err
...
...
@@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标 URL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
...
...
@@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
...
...
@@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
...
...
@@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
false
)
}
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
}
else
{
beta
:=
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
)
if
!
strings
.
Contains
(
beta
,
claude
.
BetaTokenCounting
)
{
beta
=
beta
+
","
+
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
...
...
@@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
backend/internal/service/identity_service.go
View file @
a09478f3
...
...
@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.
0.6
2 (external, cli)"
,
UserAgent
:
"claude-cli/2.
1.2
2 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.
52
.0"
,
StainlessPackageVersion
:
"0.
70
.0"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v2
2
.1
4
.0"
,
StainlessRuntimeVersion
:
"v2
4
.1
3
.0"
,
}
// Fingerprint represents account fingerprint data
...
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.
0.6
2 -> (2,
0
,
6
2)
// 例如:claude-cli/2.
1.
2 -> (2,
1
, 2)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
a09478f3
...
...
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent
:=
false
clientDisconnected
:=
false
// 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
clientDisconnected
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
payload
:=
map
[
string
]
any
{
"type"
:
"error"
,
"sequence_number"
:
0
,
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
reason
,
"code"
:
reason
,
},
}
if
b
,
err
:=
json
.
Marshal
(
payload
);
err
==
nil
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"data: %s
\n\n
"
,
b
)
flusher
.
Flush
()
}
}
needModelReplace
:=
originalModel
!=
mappedModel
...
...
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
log
.
Printf
(
"Context canceled during streaming, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if
clientDisconnected
{
log
.
Printf
(
"Upstream read error after client disconnect: %v, returning collected usage"
,
ev
.
err
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
...
...
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
data
=
correctedData
line
=
"data: "
+
correctedData
}
// Forward line
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
// 写入客户端(客户端断开后继续 drain 上游)
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
flusher
.
Flush
()
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
...
...
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
flusher
.
Flush
()
}
case
<-
intervalCh
:
...
...
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
if
clientDisconnected
{
log
.
Printf
(
"Upstream timeout after client disconnect, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
...
...
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
case
<-
keepaliveCh
:
if
clientDisconnected
{
continue
}
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
}
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
continue
}
flusher
.
Flush
()
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
a09478f3
...
...
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad
bool
}
type
cancelReadCloser
struct
{}
func
(
c
cancelReadCloser
)
Read
(
p
[]
byte
)
(
int
,
error
)
{
return
0
,
context
.
Canceled
}
func
(
c
cancelReadCloser
)
Close
()
error
{
return
nil
}
type
failingGinWriter
struct
{
gin
.
ResponseWriter
failAfter
int
writes
int
}
func
(
w
*
failingGinWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
c
.
acquireResults
!=
nil
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
...
...
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected stream_timeout SSE error, got %q"
,
rec
.
Body
.
String
())
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected OpenAI-compatible error SSE event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{},
}
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_read_error"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
failingGinWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:3,
\"
output_tokens
\"
:5,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:1}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
result
==
nil
||
result
.
usage
==
nil
{
t
.
Fatalf
(
"expected usage result"
)
}
if
result
.
usage
.
InputTokens
!=
3
||
result
.
usage
.
OutputTokens
!=
5
||
result
.
usage
.
CacheReadInputTokens
!=
1
{
t
.
Fatalf
(
"unexpected usage: %+v"
,
*
result
.
usage
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"write_failed"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
...
...
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
t
.
Fatalf
(
"expected
response_too_large
SSE e
rror
, got %q"
,
rec
.
Body
.
String
())
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
t
.
Fatalf
(
"expected
OpenAI-compatible error
SSE e
vent
, got %q"
,
rec
.
Body
.
String
())
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment