Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a09478f3
Unverified
Commit
a09478f3
authored
Feb 03, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #316 from cyhhao/fix/claude-oauth-compat
fix(网关): 完善 Claude OAuth/Claude Code 兼容
parents
0ab68aa9
2fe8932c
Changes
12
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
a09478f3
...
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
return
}
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
...
...
backend/internal/pkg/claude/constants.go
View file @
a09478f3
...
@@ -9,11 +9,26 @@ const (
...
@@ -9,11 +9,26 @@ const (
BetaClaudeCode
=
"claude-code-20250219"
BetaClaudeCode
=
"claude-code-20250219"
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaInterleavedThinking
=
"interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaFineGrainedToolStreaming
=
"fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting
=
"token-counting-2024-11-01"
)
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const
DefaultBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
const
DefaultBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const
MessageBetaHeaderNoTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const
MessageBetaHeaderWithTools
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const
CountTokensBetaHeader
=
BetaClaudeCode
+
","
+
BetaOAuth
+
","
+
BetaInterleavedThinking
+
","
+
BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
...
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
...
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var
DefaultHeaders
=
map
[
string
]
string
{
var
DefaultHeaders
=
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent"
:
"claude-cli/2.1.22 (external, cli)"
,
"X-Stainless-Lang"
:
"js"
,
"X-Stainless-Lang"
:
"js"
,
"X-Stainless-Package-Version"
:
"0.
52
.0"
,
"X-Stainless-Package-Version"
:
"0.
70
.0"
,
"X-Stainless-OS"
:
"Linux"
,
"X-Stainless-OS"
:
"Linux"
,
"X-Stainless-Arch"
:
"
x
64"
,
"X-Stainless-Arch"
:
"
arm
64"
,
"X-Stainless-Runtime"
:
"node"
,
"X-Stainless-Runtime"
:
"node"
,
"X-Stainless-Runtime-Version"
:
"v2
2
.1
4
.0"
,
"X-Stainless-Runtime-Version"
:
"v2
4
.1
3
.0"
,
"X-Stainless-Retry-Count"
:
"0"
,
"X-Stainless-Retry-Count"
:
"0"
,
"X-Stainless-Timeout"
:
"60"
,
"X-Stainless-Timeout"
:
"60
0
"
,
"X-App"
:
"cli"
,
"X-App"
:
"cli"
,
"Anthropic-Dangerous-Direct-Browser-Access"
:
"true"
,
"Anthropic-Dangerous-Direct-Browser-Access"
:
"true"
,
}
}
...
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
...
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型
// DefaultTestModel 测试时使用的默认模型
const
DefaultTestModel
=
"claude-sonnet-4-5-20250929"
const
DefaultTestModel
=
"claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var
ModelIDOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5"
:
"claude-sonnet-4-5-20250929"
,
"claude-opus-4-5"
:
"claude-opus-4-5-20251101"
,
"claude-haiku-4-5"
:
"claude-haiku-4-5-20251001"
,
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var
ModelIDReverseOverrides
=
map
[
string
]
string
{
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5"
,
"claude-haiku-4-5-20251001"
:
"claude-haiku-4-5"
,
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func
NormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func
DenormalizeModelID
(
id
string
)
string
{
if
id
==
""
{
return
id
}
if
mapped
,
ok
:=
ModelIDReverseOverrides
[
id
];
ok
{
return
mapped
}
return
id
}
backend/internal/service/account.go
View file @
a09478f3
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
...
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return
""
return
""
}
}
func
(
a
*
Account
)
GetClaudeUserID
()
string
{
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetExtraString
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"claude_user_id"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"anthropic_user_id"
));
v
!=
""
{
return
v
}
return
""
}
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
return
false
return
false
...
...
backend/internal/service/account_test_service.go
View file @
a09478f3
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
"system"
:
[]
map
[
string
]
any
{
{
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
"type"
:
"ephemeral"
,
},
},
...
...
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
a09478f3
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
tests
:=
[]
struct
{
name
string
name
string
body
string
body
string
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
},
{
{
name
:
"string system equals Claude Code prompt"
,
name
:
"string system equals Claude Code prompt"
,
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
},
{
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
},
{
{
name
:
"empty array"
,
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
a09478f3
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
a09478f3
...
@@ -20,12 +20,14 @@ import (
...
@@ -20,12 +20,14 @@ import (
"strings"
"strings"
"sync/atomic"
"sync/atomic"
"time"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
...
@@ -37,8 +39,15 @@ const (
...
@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
defaultMaxLineSize
=
40
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
if
sessionHash
==
""
{
return
""
return
""
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
return
sessionHash
[
:
8
]
return
sessionHash
[
:
8
]
}
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
var
(
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
toolPrefixRe
=
regexp
.
MustCompile
(
`(?i)^(?:oc_|mcp_)`
)
toolNameBoundaryRe
=
regexp
.
MustCompile
(
`[^a-zA-Z0-9]+`
)
toolNameCamelRe
=
regexp
.
MustCompile
(
`([a-z0-9])([A-Z])`
)
toolNameFieldRe
=
regexp
.
MustCompile
(
`"name"\s*:\s*"([^"]+)"`
)
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
"read"
:
"Read"
,
"edit"
:
"Edit"
,
"write"
:
"Write"
,
"task"
:
"Task"
,
"glob"
:
"Glob"
,
"grep"
:
"Grep"
,
"webfetch"
:
"WebFetch"
,
"websearch"
:
"WebSearch"
,
"todowrite"
:
"TodoWrite"
,
"question"
:
"AskUserQuestion"
,
}
openCodeToolOverrides
=
map
[
string
]
string
{
"Bash"
:
"bash"
,
"Read"
:
"read"
,
"Edit"
:
"edit"
,
"Write"
:
"write"
,
"Task"
:
"task"
,
"Glob"
:
"glob"
,
"Grep"
:
"grep"
,
"WebFetch"
:
"webfetch"
,
"WebSearch"
:
"websearch"
,
"TodoWrite"
:
"todowrite"
,
"AskUserQuestion"
:
"question"
,
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...
@@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
...
@@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return
newBody
return
newBody
}
}
type
claudeOAuthNormalizeOptions
struct
{
injectMetadata
bool
metadataUserID
string
stripSystemCacheControl
bool
}
func
stripToolPrefix
(
value
string
)
string
{
if
value
==
""
{
return
value
}
return
toolPrefixRe
.
ReplaceAllString
(
value
,
""
)
}
func
toPascalCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
normalized
:=
toolNameBoundaryRe
.
ReplaceAllString
(
value
,
" "
)
tokens
:=
make
([]
string
,
0
)
for
_
,
token
:=
range
strings
.
Fields
(
normalized
)
{
expanded
:=
toolNameCamelRe
.
ReplaceAllString
(
token
,
"$1 $2"
)
parts
:=
strings
.
Fields
(
expanded
)
if
len
(
parts
)
>
0
{
tokens
=
append
(
tokens
,
parts
...
)
}
}
if
len
(
tokens
)
==
0
{
return
value
}
var
builder
strings
.
Builder
for
_
,
token
:=
range
tokens
{
lower
:=
strings
.
ToLower
(
token
)
if
lower
==
""
{
continue
}
runes
:=
[]
rune
(
lower
)
runes
[
0
]
=
unicode
.
ToUpper
(
runes
[
0
])
_
,
_
=
builder
.
WriteString
(
string
(
runes
))
}
return
builder
.
String
()
}
func
toSnakeCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
output
:=
toolNameCamelRe
.
ReplaceAllString
(
value
,
"$1_$2"
)
output
=
toolNameBoundaryRe
.
ReplaceAllString
(
output
,
"_"
)
output
=
strings
.
Trim
(
output
,
"_"
)
return
strings
.
ToLower
(
output
)
}
func
normalizeToolNameForClaude
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
mapped
,
ok
:=
claudeToolNameOverrides
[
strings
.
ToLower
(
stripped
)]
if
!
ok
{
mapped
=
toPascalCase
(
stripped
)
}
if
mapped
!=
""
&&
cache
!=
nil
&&
mapped
!=
stripped
{
cache
[
mapped
]
=
stripped
}
if
mapped
==
""
{
return
stripped
}
return
mapped
}
func
normalizeToolNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
stripped
];
ok
{
return
mapped
}
}
if
mapped
,
ok
:=
openCodeToolOverrides
[
stripped
];
ok
{
return
mapped
}
return
toSnakeCase
(
stripped
)
}
func
normalizeParamNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
name
];
ok
{
return
mapped
}
}
return
name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
func
sanitizeToolDescription
(
description
string
)
string
{
if
description
==
""
{
return
description
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
schema
,
ok
:=
inputSchema
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
properties
,
ok
:=
schema
[
"properties"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
newProperties
:=
make
(
map
[
string
]
any
,
len
(
properties
))
for
key
,
value
:=
range
properties
{
snakeKey
:=
toSnakeCase
(
key
)
newProperties
[
snakeKey
]
=
value
if
snakeKey
!=
key
&&
cache
!=
nil
{
cache
[
snakeKey
]
=
key
}
}
schema
[
"properties"
]
=
newProperties
if
required
,
ok
:=
schema
[
"required"
]
.
([]
any
);
ok
{
newRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
item
:=
range
required
{
name
,
ok
:=
item
.
(
string
)
if
!
ok
{
newRequired
=
append
(
newRequired
,
item
)
continue
}
snakeName
:=
toSnakeCase
(
name
)
newRequired
=
append
(
newRequired
,
snakeName
)
if
snakeName
!=
name
&&
cache
!=
nil
{
cache
[
snakeName
]
=
name
}
}
schema
[
"required"
]
=
newRequired
}
}
func
stripCacheControlFromSystemBlocks
(
system
any
)
bool
{
blocks
,
ok
:=
system
.
([]
any
)
if
!
ok
{
return
false
}
changed
:=
false
for
_
,
item
:=
range
blocks
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
_
,
exists
:=
block
[
"cache_control"
];
!
exists
{
continue
}
delete
(
block
,
"cache_control"
)
changed
=
true
}
return
changed
}
func
normalizeClaudeOAuthRequestBody
(
body
[]
byte
,
modelID
string
,
opts
claudeOAuthNormalizeOptions
)
([]
byte
,
string
,
map
[
string
]
string
)
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
case
[]
any
:
for
_
,
item
:=
range
v
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
block
[
"type"
]
.
(
string
);
blockType
!=
"text"
{
continue
}
text
,
ok
:=
block
[
"text"
]
.
(
string
)
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
}
}
}
if
rawModel
,
ok
:=
req
[
"model"
]
.
(
string
);
ok
{
normalized
:=
claude
.
NormalizeModelID
(
rawModel
)
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
}
}
if
rawTools
,
exists
:=
req
[
"tools"
];
exists
{
switch
tools
:=
rawTools
.
(
type
)
{
case
[]
any
:
for
idx
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
tools
[
idx
]
=
toolMap
}
req
[
"tools"
]
=
tools
case
map
[
string
]
any
:
normalizedTools
:=
make
(
map
[
string
]
any
,
len
(
tools
))
for
name
,
value
:=
range
tools
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
==
""
{
normalized
=
name
}
if
toolMap
,
ok
:=
value
.
(
map
[
string
]
any
);
ok
{
toolMap
[
"name"
]
=
normalized
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
normalizedTools
[
normalized
]
=
toolMap
continue
}
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
continue
}
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
);
blockType
!=
"tool_use"
{
continue
}
if
name
,
ok
:=
blockMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
}
}
}
}
}
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
}
}
if
opts
.
injectMetadata
&&
opts
.
metadataUserID
!=
""
{
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
metadata
=
map
[
string
]
any
{}
req
[
"metadata"
]
=
metadata
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
return
newBody
,
modelID
,
toolNameMap
}
func
(
s
*
GatewayService
)
buildOAuthMetadataUserID
(
parsed
*
ParsedRequest
,
account
*
Account
,
fp
*
Fingerprint
)
string
{
if
parsed
==
nil
||
account
==
nil
{
return
""
}
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
sessionID
:=
uuid
.
NewString
()
if
sessionHash
!=
""
{
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
if
seed
==
""
{
return
uuid
.
NewString
()
}
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// SelectAccount 选择账号(粘性会话+优先级)
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
...
@@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
...
@@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
}
}
func
isClaudeCodeRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
parsed
*
ParsedRequest
)
bool
{
if
IsClaudeCodeClient
(
ctx
)
{
return
true
}
if
parsed
==
nil
||
c
==
nil
{
return
false
}
return
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
...
@@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
var
newSystem
[]
any
...
@@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
}
case
[]
any
:
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
continue
}
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
newSystem
=
append
(
newSystem
,
item
)
}
}
...
@@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
reqStream
:=
parsed
.
Stream
originalModel
:=
reqModel
var
toolNameMap
map
[
string
]
string
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
if
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
&&
fp
!=
nil
{
if
metadataUserID
:=
s
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
);
metadataUserID
!=
""
{
normalizeOpts
.
injectMetadata
=
true
normalizeOpts
.
metadataUserID
=
metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
body
,
reqModel
,
toolNameMap
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
account
.
IsOAuth
()
&&
!
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
&&
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射(仅对apikey类型账号)
originalModel
:=
reqModel
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
if
mappedModel
!=
reqModel
{
...
@@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
if
retryErr2
==
nil
{
...
@@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
var
clientDisconnect
bool
if
reqStream
{
if
reqStream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
)
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
.
Error
()
==
"have error in stream"
{
if
err
.
Error
()
==
"have error in stream"
{
return
nil
,
&
UpstreamFailoverError
{
return
nil
,
&
UpstreamFailoverError
{
...
@@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs
=
streamResult
.
firstTokenMs
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
)
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
},
nil
},
nil
}
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
// 确定目标URL
targetURL
:=
claudeAPIURL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
// 失败时降级为透传原始headers
...
@@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 白名单透传headers
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
reqStream
)
}
// 处理anthropic-beta header(OAuth账号需要
特殊处理
)
// 处理
anthropic-beta header(OAuth
账号需要
包含 oauth beta
)
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
))
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
@@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string {
...
@@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return
claude
.
APIKeyBetaHeader
return
claude
.
APIKeyBetaHeader
}
}
func
applyClaudeOAuthHeaderDefaults
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
if
req
.
Header
.
Get
(
key
)
==
""
{
req
.
Header
.
Set
(
key
,
value
)
}
}
if
isStream
&&
req
.
Header
.
Get
(
"x-stainless-helper-method"
)
==
""
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
add
:=
func
(
v
string
)
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
if
maxBytes
<=
0
{
maxBytes
=
2048
maxBytes
=
2048
...
@@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
...
@@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
@@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
...
@@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
@@ -3130,7 +3893,7 @@ type streamingResult struct {
...
@@ -3130,7 +3893,7 @@ type streamingResult struct {
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines
:=
make
([]
string
,
0
,
4
)
var
toolInputBuffers
map
[
int
]
string
if
mimicClaudeCode
{
toolInputBuffers
=
make
(
map
[
int
]
string
)
}
transformToolInputJSON
:=
func
(
raw
string
)
string
{
if
!
mimicClaudeCode
{
return
raw
}
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
var
parsed
any
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
parsed
);
err
!=
nil
{
return
replaceToolNamesInText
(
raw
,
toolNameMap
)
}
rewritten
,
changed
:=
rewriteParamKeysInValue
(
parsed
,
toolNameMap
)
if
changed
{
if
bytes
,
err
:=
json
.
Marshal
(
rewritten
);
err
==
nil
{
return
string
(
bytes
)
}
}
return
raw
}
processSSEEvent
:=
func
(
lines
[]
string
)
([]
string
,
string
,
error
)
{
if
len
(
lines
)
==
0
{
return
nil
,
""
,
nil
}
eventName
:=
""
dataLine
:=
""
for
_
,
line
:=
range
lines
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
{
eventName
=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
))
continue
}
if
dataLine
==
""
&&
sseDataRe
.
MatchString
(
trimmed
)
{
dataLine
=
sseDataRe
.
ReplaceAllString
(
trimmed
,
""
)
}
}
if
eventName
==
"error"
{
return
nil
,
dataLine
,
errors
.
New
(
"have error in stream"
)
}
if
dataLine
==
""
{
return
[]
string
{
strings
.
Join
(
lines
,
"
\n
"
)
+
"
\n\n
"
},
""
,
nil
}
if
dataLine
==
"[DONE]"
{
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
dataLine
+
"
\n\n
"
return
[]
string
{
block
},
dataLine
,
nil
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
dataLine
),
&
event
);
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
eventType
,
_
:=
event
[
"type"
]
.
(
string
)
if
eventName
==
""
{
eventName
=
eventType
}
if
needModelReplace
{
if
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
model
,
ok
:=
msg
[
"model"
]
.
(
string
);
ok
&&
model
==
mappedModel
{
msg
[
"model"
]
=
originalModel
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_delta"
{
if
delta
,
ok
:=
event
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
deltaType
,
_
:=
delta
[
"type"
]
.
(
string
);
deltaType
==
"input_json_delta"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
partial
,
ok
:=
delta
[
"partial_json"
]
.
(
string
);
ok
{
toolInputBuffers
[
index
]
+=
partial
}
}
return
nil
,
dataLine
,
nil
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_stop"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
buffered
:=
toolInputBuffers
[
index
];
buffered
!=
""
{
delete
(
toolInputBuffers
,
index
)
transformed
:=
transformToolInputJSON
(
buffered
)
synthetic
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
any
{
"type"
:
"input_json_delta"
,
"partial_json"
:
transformed
,
},
}
synthBytes
,
synthErr
:=
json
.
Marshal
(
synthetic
)
if
synthErr
==
nil
{
synthBlock
:=
"event: content_block_delta
\n
"
+
"data: "
+
string
(
synthBytes
)
+
"
\n\n
"
rewriteToolNamesInValue
(
event
,
toolNameMap
)
stopBytes
,
stopErr
:=
json
.
Marshal
(
event
)
if
stopErr
==
nil
{
stopBlock
:=
""
if
eventName
!=
""
{
stopBlock
=
"event: "
+
eventName
+
"
\n
"
}
stopBlock
+=
"data: "
+
string
(
stopBytes
)
+
"
\n\n
"
return
[]
string
{
synthBlock
,
stopBlock
},
string
(
stopBytes
),
nil
}
}
}
}
}
if
mimicClaudeCode
{
rewriteToolNamesInValue
(
event
,
toolNameMap
)
}
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
string
(
newData
)
+
"
\n\n
"
return
[]
string
{
block
},
string
(
newData
),
nil
}
for
{
for
{
select
{
select
{
case
ev
,
ok
:=
<-
events
:
case
ev
,
ok
:=
<-
events
:
...
@@ -3253,43 +4181,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3253,43 +4181,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
}
line
:=
ev
.
line
line
:=
ev
.
line
if
line
==
"event: error"
{
trimmed
:=
strings
.
TrimSpace
(
line
)
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
errors
.
New
(
"have error in stream"
)
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
trimmed
==
""
{
var
data
string
if
len
(
pendingEventLines
)
==
0
{
if
sseDataRe
.
MatchString
(
line
)
{
continue
data
=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
}
}
// 写入客户端(统一处理 data 行和非 data 行)
outputBlocks
,
data
,
err
:=
processSSEEvent
(
pendingEventLines
)
if
!
clientDisconnected
{
pendingEventLines
=
pendingEventLines
[
:
0
]
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
err
!=
nil
{
clientDisconnected
=
true
if
clientDisconnected
{
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
else
{
}
flusher
.
Flush
()
return
nil
,
err
}
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
for
_
,
block
:=
range
outputBlocks
{
if
data
!=
""
{
if
!
clientDisconnected
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
if
_
,
werr
:=
fmt
.
Fprint
(
w
,
block
);
werr
!=
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
clientDisconnected
=
true
firstTokenMs
=
&
ms
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
break
}
flusher
.
Flush
()
}
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
}
s
.
parseSSEUsage
(
data
,
usage
)
continue
}
}
pendingEventLines
=
append
(
pendingEventLines
,
line
)
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
@@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
rewriteParamKeysInValue
(
value
any
,
cache
map
[
string
]
string
)
(
any
,
bool
)
{
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
switch
v
:=
value
.
(
type
)
{
if
!
sseDataRe
.
MatchString
(
line
)
{
case
map
[
string
]
any
:
return
line
changed
:=
false
}
rewritten
:=
make
(
map
[
string
]
any
,
len
(
v
))
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
for
key
,
item
:=
range
v
{
if
data
==
""
||
data
==
"[DONE]"
{
newKey
:=
normalizeParamNameForOpenCode
(
key
,
cache
)
return
line
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
}
if
childChanged
{
changed
=
true
var
event
map
[
string
]
any
}
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
if
newKey
!=
key
{
return
line
changed
=
true
}
}
rewritten
[
newKey
]
=
newItem
// 只替换 message_start 事件中的 message.model
}
if
event
[
"type"
]
!=
"message_start"
{
if
!
changed
{
return
line
return
value
,
false
}
return
rewritten
,
true
case
[]
any
:
changed
:=
false
rewritten
:=
make
([]
any
,
len
(
v
))
for
idx
,
item
:=
range
v
{
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
rewritten
[
idx
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
default
:
return
value
,
false
}
}
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
)
func
rewriteToolNamesInValue
(
value
any
,
toolNameMap
map
[
string
]
string
)
bool
{
if
!
ok
{
switch
v
:=
value
.
(
type
)
{
return
line
case
map
[
string
]
any
:
changed
:=
false
if
blockType
,
_
:=
v
[
"type"
]
.
(
string
);
blockType
==
"tool_use"
{
if
name
,
ok
:=
v
[
"name"
]
.
(
string
);
ok
{
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
!=
name
{
v
[
"name"
]
=
mapped
changed
=
true
}
}
if
input
,
ok
:=
v
[
"input"
]
.
(
map
[
string
]
any
);
ok
{
rewrittenInput
,
inputChanged
:=
rewriteParamKeysInValue
(
input
,
toolNameMap
)
if
inputChanged
{
if
m
,
ok
:=
rewrittenInput
.
(
map
[
string
]
any
);
ok
{
v
[
"input"
]
=
m
changed
=
true
}
}
}
}
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
case
[]
any
:
changed
:=
false
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
default
:
return
false
}
}
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
func
replaceToolNamesInText
(
text
string
,
toolNameMap
map
[
string
]
string
)
string
{
if
!
ok
||
model
!=
fromModel
{
if
text
==
""
{
return
line
return
text
}
}
output
:=
toolNameFieldRe
.
ReplaceAllStringFunc
(
text
,
func
(
match
string
)
string
{
submatches
:=
toolNameFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
name
:=
submatches
[
1
]
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
==
name
{
return
match
}
return
strings
.
Replace
(
match
,
name
,
mapped
,
1
)
})
output
=
modelFieldRe
.
ReplaceAllStringFunc
(
output
,
func
(
match
string
)
string
{
submatches
:=
modelFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
model
:=
submatches
[
1
]
mapped
:=
claude
.
DenormalizeModelID
(
model
)
if
mapped
==
model
{
return
match
}
return
strings
.
Replace
(
match
,
model
,
mapped
,
1
)
})
msg
[
"model"
]
=
toModel
for
mapped
,
original
:=
range
toolNameMap
{
newData
,
err
:=
json
.
Marshal
(
event
)
if
mapped
==
""
||
original
==
""
||
mapped
==
original
{
if
err
!=
nil
{
continue
return
line
}
output
=
strings
.
ReplaceAll
(
output
,
"
\"
"
+
mapped
+
"
\"
:"
,
"
\"
"
+
original
+
"
\"
:"
)
output
=
strings
.
ReplaceAll
(
output
,
"
\\\"
"
+
mapped
+
"
\\\"
:"
,
"
\\\"
"
+
original
+
"
\\\"
:"
)
}
}
return
"data: "
+
string
(
newData
)
return
output
}
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
...
@@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
...
@@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
...
@@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if
originalModel
!=
mappedModel
{
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
}
if
mimicClaudeCode
{
body
=
s
.
replaceToolNamesInResponseBody
(
body
,
toolNameMap
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
@@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
...
@@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
newBody
return
newBody
}
}
func
(
s
*
GatewayService
)
replaceToolNamesInResponseBody
(
body
[]
byte
,
toolNameMap
map
[
string
]
string
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
replaced
:=
replaceToolNamesInText
(
string
(
body
),
toolNameMap
)
if
replaced
==
string
(
body
)
{
return
body
}
return
[]
byte
(
replaced
)
}
if
!
rewriteToolNamesInValue
(
resp
,
toolNameMap
)
{
return
body
}
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
...
@@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
body
,
reqModel
,
_
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if
account
.
Platform
==
PlatformAntigravity
{
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
...
@@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// 构建上游请求
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
return
err
return
err
...
@@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// buildCountTokensRequest 构建 count_tokens 上游请求
// buildCountTokensRequest 构建 count_tokens 上游请求
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标 URL
// 确定目标 URL
targetURL
:=
claudeAPICountTokensURL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
...
@@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 白名单透传 headers
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
}
...
@@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
false
)
}
// OAuth 账号:处理 anthropic-beta header
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
}
else
{
beta
:=
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
)
if
!
strings
.
Contains
(
beta
,
claude
.
BetaTokenCounting
)
{
beta
=
beta
+
","
+
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
...
backend/internal/service/identity_service.go
View file @
a09478f3
...
@@ -26,13 +26,13 @@ var (
...
@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.
0.6
2 (external, cli)"
,
UserAgent
:
"claude-cli/2.
1.2
2 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.
52
.0"
,
StainlessPackageVersion
:
"0.
70
.0"
,
StainlessOS
:
"Linux"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v2
2
.1
4
.0"
,
StainlessRuntimeVersion
:
"v2
4
.1
3
.0"
,
}
}
// Fingerprint represents account fingerprint data
// Fingerprint represents account fingerprint data
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
}
// parseUserAgentVersion 解析user-agent版本号
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.
0.6
2 -> (2,
0
,
6
2)
// 例如:claude-cli/2.
1.
2 -> (2,
1
, 2)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
a09478f3
...
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent
:=
false
errorEventSent
:=
false
clientDisconnected
:=
false
// 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
clientDisconnected
{
return
return
}
}
errorEventSent
=
true
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
payload
:=
map
[
string
]
any
{
flusher
.
Flush
()
"type"
:
"error"
,
"sequence_number"
:
0
,
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
reason
,
"code"
:
reason
,
},
}
if
b
,
err
:=
json
.
Marshal
(
payload
);
err
==
nil
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"data: %s
\n\n
"
,
b
)
flusher
.
Flush
()
}
}
}
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
...
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
}
if
ev
.
err
!=
nil
{
if
ev
.
err
!=
nil
{
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
log
.
Printf
(
"Context canceled during streaming, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if
clientDisconnected
{
log
.
Printf
(
"Upstream read error after client disconnect: %v, returning collected usage"
,
ev
.
err
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
sendErrorEvent
(
"response_too_large"
)
...
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
data
=
correctedData
line
=
"data: "
+
correctedData
line
=
"data: "
+
correctedData
}
}
// Forward line
// 写入客户端(客户端断开后继续 drain 上游)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
!
clientDisconnected
{
sendErrorEvent
(
"write_failed"
)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
}
flusher
.
Flush
()
// Record first token time
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
...
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s
.
parseSSEUsage
(
data
,
usage
)
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
}
else
{
// Forward non-data lines as-is
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
!
clientDisconnected
{
sendErrorEvent
(
"write_failed"
)
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
}
}
}
flusher
.
Flush
()
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
...
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
continue
}
}
if
clientDisconnected
{
log
.
Printf
(
"Upstream timeout after client disconnect, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
if
s
.
rateLimitService
!=
nil
{
...
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
case
<-
keepaliveCh
:
case
<-
keepaliveCh
:
if
clientDisconnected
{
continue
}
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
continue
}
}
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
continue
}
}
flusher
.
Flush
()
flusher
.
Flush
()
}
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
a09478f3
...
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
...
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad
bool
skipDefaultLoad
bool
}
}
type
cancelReadCloser
struct
{}
func
(
c
cancelReadCloser
)
Read
(
p
[]
byte
)
(
int
,
error
)
{
return
0
,
context
.
Canceled
}
func
(
c
cancelReadCloser
)
Close
()
error
{
return
nil
}
type
failingGinWriter
struct
{
gin
.
ResponseWriter
failAfter
int
writes
int
}
func
(
w
*
failingGinWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
c
.
acquireResults
!=
nil
{
if
c
.
acquireResults
!=
nil
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
if
result
,
ok
:=
c
.
acquireResults
[
accountID
];
ok
{
...
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
...
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected stream_timeout SSE error, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected OpenAI-compatible error SSE event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{},
}
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_read_error"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
failingGinWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:3,
\"
output_tokens
\"
:5,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:1}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
result
==
nil
||
result
.
usage
==
nil
{
t
.
Fatalf
(
"expected usage result"
)
}
if
result
.
usage
.
InputTokens
!=
3
||
result
.
usage
.
OutputTokens
!=
5
||
result
.
usage
.
CacheReadInputTokens
!=
1
{
t
.
Fatalf
(
"unexpected usage: %+v"
,
*
result
.
usage
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"write_failed"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
...
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
t
.
Fatalf
(
"expected
response_too_large
SSE e
rror
, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected
OpenAI-compatible error
SSE e
vent
, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment