Unverified Commit a09478f3 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #316 from cyhhao/fix/claude-oauth-compat

fix(网关): 完善 Claude OAuth/Claude Code 兼容
parents 0ab68aa9 2fe8932c
...@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)
......
...@@ -9,11 +9,26 @@ const ( ...@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219" BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
) )
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking ...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。 // DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", // Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js", "X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0", "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux", "X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64", "X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node", "X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0", "X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0", "X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60", "X-Stainless-Timeout": "600",
"X-App": "cli", "X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true", "Anthropic-Dangerous-Direct-Browser-Access": "true",
} }
...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string { ...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型 // DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929" const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}
...@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string { ...@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return "" return ""
} }
func (a *Account) GetClaudeUserID() string {
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
return v
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
......
...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) { ...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{ "system": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.", "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{ "cache_control": map[string]string{
"type": "ephemeral", "type": "ephemeral",
}, },
......
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { ...@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
} }
func TestInjectClaudeCodePrompt(t *testing.T) { func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct { tests := []struct {
name string name string
body string body string
...@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt", system: "Custom prompt",
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt", wantSecondText: claudePrefix + "\n\nCustom prompt",
}, },
{ {
name: "string system equals Claude Code prompt", name: "string system equals Claude Code prompt",
...@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2 // Claude Code + Custom = 2
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom", wantSecondText: claudePrefix + "\n\nCustom",
}, },
{ {
name: "array system with existing Claude Code prompt (should dedupe)", name: "array system with existing Claude Code prompt (should dedupe)",
...@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped) // Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other", wantSecondText: claudePrefix + "\n\nOther",
}, },
{ {
name: "empty array", name: "empty array",
......
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}
...@@ -20,12 +20,14 @@ import ( ...@@ -20,12 +20,14 @@ import (
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -37,15 +39,27 @@ const ( ...@@ -37,15 +39,27 @@ const (
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 40 * 1024 * 1024 defaultMaxLineSize = 40 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
) )
const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
func (s *GatewayService) debugModelRoutingEnabled() bool { func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on" return v == "1" || v == "true" || v == "yes" || v == "on"
} }
func (s *GatewayService) debugClaudeMimicEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func shortSessionHash(sessionHash string) string { func shortSessionHash(sessionHash string) string {
if sessionHash == "" { if sessionHash == "" {
return "" return ""
...@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string { ...@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
return sessionHash[:8] return sessionHash[:8]
} }
func redactAuthHeaderValue(v string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
// Keep scheme for debugging, redact secret.
if strings.HasPrefix(strings.ToLower(v), "bearer ") {
return "Bearer [redacted]"
}
return "[redacted]"
}
func safeHeaderValueForLog(key string, v string) string {
key = strings.ToLower(strings.TrimSpace(key))
switch key {
case "authorization", "x-api-key":
return redactAuthHeaderValue(v)
default:
return strings.TrimSpace(v)
}
}
func extractSystemPreviewFromBody(body []byte) string {
if len(body) == 0 {
return ""
}
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return ""
}
switch {
case sys.IsArray():
for _, item := range sys.Array() {
if !item.IsObject() {
continue
}
if strings.EqualFold(item.Get("type").String(), "text") {
if t := item.Get("text").String(); strings.TrimSpace(t) != "" {
return t
}
}
}
return ""
case sys.Type == gjson.String:
return sys.String()
default:
return ""
}
}
func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string {
if req == nil {
return ""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting := []string{
"user-agent",
"x-app",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"anthropic-beta",
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-retry-count",
"x-stainless-timeout",
"authorization",
"x-api-key",
"content-type",
"accept",
"x-stainless-helper-method",
}
h := make([]string, 0, len(interesting))
for _, k := range interesting {
if v := req.Header.Get(k); v != "" {
h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v)))
}
}
metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String())
sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body))
// Truncate preview to keep logs sane.
if len(sysPreview) > 300 {
sysPreview = sysPreview[:300] + "..."
}
sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n")
sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r")
aid := int64(0)
aname := ""
if account != nil {
aid = account.ID
aname = account.Name
}
return fmt.Sprintf(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}",
req.URL.String(),
aid,
aname,
tokenType,
mimicClaudeCode,
metaUserID,
sysPreview,
strings.Join(h, " "),
)
}
func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) {
line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)
if line == "" {
return
}
log.Printf("[ClaudeMimicDebug] %s", line)
}
func isClaudeCodeCredentialScopeError(msg string) bool {
m := strings.ToLower(strings.TrimSpace(msg))
if m == "" {
return false
}
return strings.Contains(m, "only authorized for use with claude code") &&
strings.Contains(m, "cannot be used for other api requests")
}
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
"read": "Read",
"edit": "Edit",
"write": "Write",
"task": "Task",
"glob": "Glob",
"grep": "Grep",
"webfetch": "WebFetch",
"websearch": "WebSearch",
"todowrite": "TodoWrite",
"question": "AskUserQuestion",
}
openCodeToolOverrides = map[string]string{
"Bash": "bash",
"Read": "read",
"Edit": "edit",
"Write": "write",
"Task": "task",
"Glob": "glob",
"Grep": "grep",
"WebFetch": "webfetch",
"WebSearch": "websearch",
"TodoWrite": "todowrite",
"AskUserQuestion": "question",
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...@@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte ...@@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return newBody return newBody
} }
type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
}
func stripToolPrefix(value string) string {
if value == "" {
return value
}
return toolPrefixRe.ReplaceAllString(value, "")
}
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
_, _ = builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string {
if value == "" {
return value
}
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
output = strings.Trim(output, "_")
return strings.ToLower(output)
}
func normalizeToolNameForClaude(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped
}
return mapped
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
return toSnakeCase(stripped)
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
if cache != nil {
if mapped, ok := cache[name]; ok {
return mapped
}
}
return name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func sanitizeSystemText(text string) string {
if text == "" {
return text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text = strings.ReplaceAll(
text,
"You are OpenCode, the best coding agent on the planet.",
strings.TrimSpace(claudeCodeSystemPrompt),
)
return text
}
func sanitizeToolDescription(description string) string {
if description == "" {
return description
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return description
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
schema, ok := inputSchema.(map[string]any)
if !ok {
return
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
return
}
newProperties := make(map[string]any, len(properties))
for key, value := range properties {
snakeKey := toSnakeCase(key)
newProperties[snakeKey] = value
if snakeKey != key && cache != nil {
cache[snakeKey] = key
}
}
schema["properties"] = newProperties
if required, ok := schema["required"].([]any); ok {
newRequired := make([]any, 0, len(required))
for _, item := range required {
name, ok := item.(string)
if !ok {
newRequired = append(newRequired, item)
continue
}
snakeName := toSnakeCase(name)
newRequired = append(newRequired, snakeName)
if snakeName != name && cache != nil {
cache[snakeName] = name
}
}
schema["required"] = newRequired
}
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
return false
}
changed := false
for _, item := range blocks {
block, ok := item.(map[string]any)
if !ok {
continue
}
if _, exists := block["cache_control"]; !exists {
continue
}
delete(block, "cache_control")
changed = true
}
return changed
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
if len(body) == 0 {
return body, modelID, nil
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
}
toolNameMap := make(map[string]string)
if system, ok := req["system"]; ok {
switch v := system.(type) {
case string:
sanitized := sanitizeSystemText(v)
if sanitized != v {
req["system"] = sanitized
}
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
if blockType, _ := block["type"].(string); blockType != "text" {
continue
}
text, ok := block["text"].(string)
if !ok || text == "" {
continue
}
sanitized := sanitizeSystemText(text)
if sanitized != text {
block["text"] = sanitized
}
}
}
}
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
req["model"] = normalized
modelID = normalized
}
}
if rawTools, exists := req["tools"]; exists {
switch tools := rawTools.(type) {
case []any:
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
if name, ok := toolMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
tools[idx] = toolMap
}
req["tools"] = tools
case map[string]any:
normalizedTools := make(map[string]any, len(tools))
for name, value := range tools {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized == "" {
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
toolMap["name"] = normalized
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
}
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
}
} else {
req["tools"] = []any{}
}
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
continue
}
if name, ok := blockMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
}
}
}
}
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
}
}
if opts.injectMetadata && opts.metadataUserID != "" {
metadata, ok := req["metadata"].(map[string]any)
if !ok {
metadata = map[string]any{}
req["metadata"] = metadata
}
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID
}
}
delete(req, "temperature")
delete(req, "tool_choice")
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
return newBody, modelID, toolNameMap
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
if parsed == nil || account == nil {
return ""
}
if parsed.MetadataUserID != "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID = generateClientID()
}
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
if accountUUID != "" {
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
}
func generateSessionUUID(seed string) string {
if seed == "" {
return uuid.NewString()
}
hash := sha256.Sum256([]byte(seed))
bytes := hash[:16]
bytes[6] = (bytes[6] & 0x0f) | 0x40
bytes[8] = (bytes[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
// SelectAccount 选择账号(粘性会话+优先级) // SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
...@@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { ...@@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent) return claudeCliUserAgentRe.MatchString(userAgent)
} }
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool { func systemIncludesClaudeCodePrompt(system any) bool {
...@@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { ...@@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text": claudeCodeSystemPrompt, "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"}, "cache_control": map[string]string{"type": "ephemeral"},
} }
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
var newSystem []any var newSystem []any
...@@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { ...@@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case nil: case nil:
newSystem = []any{claudeCodeBlock} newSystem = []any{claudeCodeBlock}
case string: case string:
if v == "" || v == claudeCodeSystemPrompt { // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
newSystem = []any{claudeCodeBlock} newSystem = []any{claudeCodeBlock}
} else { } else {
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} // Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged := v
if !strings.HasPrefix(v, claudeCodePrefix) {
merged = claudeCodePrefix + "\n\n" + v
}
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
} }
case []any: case []any:
newSystem = make([]any, 0, len(v)+1) newSystem = make([]any, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock) newSystem = append(newSystem, claudeCodeBlock)
prefixedNext := false
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
continue continue
} }
// Prefix the first subsequent text system block once.
if !prefixedNext {
if blockType, _ := m["type"].(string); blockType == "text" {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
m["text"] = claudeCodePrefix + "\n\n" + text
prefixedNext = true
}
}
}
} }
newSystem = append(newSystem, item) newSystem = append(newSystem, item)
} }
...@@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
reqStream := parsed.Stream reqStream := parsed.Stream
originalModel := reqModel
var toolNameMap map[string]string
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if account.IsOAuth() && if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) { !systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System) body = injectClaudeCodePrompt(body, parsed.System)
} }
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = metadataUserID
}
}
}
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// 强制执行 cache_control 块数量限制(最多 4 个) // 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body) body = enforceCacheControlLimit(body)
// 应用模型映射(仅对apikey类型账号) // 应用模型映射(仅对apikey类型账号)
originalModel := reqModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
...@@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
// Capture upstream request body for ops retry of this attempt. // Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body)) c.Set(OpsUpstreamRequestBodyKey, string(body))
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text. // also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
...@@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil { if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil { if retryErr2 == nil {
...@@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
if reqStream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil { if err != nil {
if err.Error() == "have error in stream" { if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
...@@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect clientDisconnect = streamResult.clientDisconnect
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
...@@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth账号:应用统一指纹 // OAuth账号:应用统一指纹
var fingerprint *Fingerprint var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID) // 1. 获取或创建指纹(包含随机生成的ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err != nil { if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers // 失败时降级为透传原始headers
...@@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
// 白名单透传headers // 白名单透传headers
for key, values := range c.Request.Header { for key, values := range clientHeaders {
lowerKey := strings.ToLower(key) lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] { if allowedHeaders[lowerKey] {
for _, v := range values { for _, v := range values {
...@@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" { if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
} }
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// 处理anthropic-beta header(OAuth账号需要特殊处理 // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) if mimicClaudeCode {
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := req.Header.Get("anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
...@@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil return req, nil
} }
...@@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string { ...@@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader return claude.APIKeyBetaHeader
} }
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
if req == nil {
return
}
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "application/json")
}
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
if req.Header.Get(key) == "" {
req.Header.Set(key, value)
}
}
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
func mergeAnthropicBeta(required []string, incoming string) string {
seen := make(map[string]struct{}, len(required)+8)
out := make([]string, 0, len(required)+8)
add := func(v string) {
v = strings.TrimSpace(v)
if v == "" {
return
}
if _, ok := seen[v]; ok {
return
}
seen[v] = struct{}{}
out = append(out, v)
}
for _, r := range required {
add(r)
}
for _, p := range strings.Split(incoming, ",") {
add(p)
}
return strings.Join(out, ",")
}
func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string {
merged := mergeAnthropicBeta(required, incoming)
if merged == "" || len(drop) == 0 {
return merged
}
out := make([]string, 0, 8)
for _, p := range strings.Split(merged, ",") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := drop[p]; ok {
continue
}
out = append(out, p)
}
return strings.Join(out, ",")
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if req == nil {
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults(req, isStream)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
req.Header.Set(key, value)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req.Header.Set("accept", "application/json")
if isStream {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
func truncateForLog(b []byte, maxBytes int) string { func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 { if maxBytes <= 0 {
maxBytes = 2048 maxBytes = 2048
...@@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail := "" upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
...@@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ...@@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
upstreamDetail := "" upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
...@@ -3130,7 +3893,7 @@ type streamingResult struct { ...@@ -3130,7 +3893,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开 clientDisconnect bool // 客户端是否在流式传输过程中断开
} }
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
...@@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines := make([]string, 0, 4)
var toolInputBuffers map[int]string
if mimicClaudeCode {
toolInputBuffers = make(map[int]string)
}
transformToolInputJSON := func(raw string) string {
if !mimicClaudeCode {
return raw
}
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
var parsed any
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
return replaceToolNamesInText(raw, toolNameMap)
}
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
if changed {
if bytes, err := json.Marshal(rewritten); err == nil {
return string(bytes)
}
}
return raw
}
processSSEEvent := func(lines []string) ([]string, string, error) {
if len(lines) == 0 {
return nil, "", nil
}
eventName := ""
dataLine := ""
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "event:") {
eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
continue
}
if dataLine == "" && sseDataRe.MatchString(trimmed) {
dataLine = sseDataRe.ReplaceAllString(trimmed, "")
}
}
if eventName == "error" {
return nil, dataLine, errors.New("have error in stream")
}
if dataLine == "" {
return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
}
if dataLine == "[DONE]" {
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
var event map[string]any
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
}
eventType, _ := event["type"].(string)
if eventName == "" {
eventName = eventType
}
if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
msg["model"] = originalModel
}
}
}
if mimicClaudeCode && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuffers[index] += partial
}
}
return nil, dataLine, nil
}
}
}
if mimicClaudeCode && eventType == "content_block_stop" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if buffered := toolInputBuffers[index]; buffered != "" {
delete(toolInputBuffers, index)
transformed := transformToolInputJSON(buffered)
synthetic := map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": transformed,
},
}
synthBytes, synthErr := json.Marshal(synthetic)
if synthErr == nil {
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
rewriteToolNamesInValue(event, toolNameMap)
stopBytes, stopErr := json.Marshal(event)
if stopErr == nil {
stopBlock := ""
if eventName != "" {
stopBlock = "event: " + eventName + "\n"
}
stopBlock += "data: " + string(stopBytes) + "\n\n"
return []string{synthBlock, stopBlock}, string(stopBytes), nil
}
}
}
}
}
if mimicClaudeCode {
rewriteToolNamesInValue(event, toolNameMap)
}
newData, err := json.Marshal(event)
if err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + string(newData) + "\n\n"
return []string{block}, string(newData), nil
}
for { for {
select { select {
case ev, ok := <-events: case ev, ok := <-events:
...@@ -3253,35 +4181,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -3253,35 +4181,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
} }
line := ev.line line := ev.line
if line == "event: error" { trimmed := strings.TrimSpace(line)
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected { if trimmed == "" {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil if len(pendingEventLines) == 0 {
} continue
return nil, errors.New("have error in stream")
} }
// Extract data from SSE line (supports both "data: " and "data:" formats) outputBlocks, data, err := processSSEEvent(pendingEventLines)
var data string pendingEventLines = pendingEventLines[:0]
if sseDataRe.MatchString(line) { if err != nil {
data = sseDataRe.ReplaceAllString(line, "") if clientDisconnected {
// 如果有模型映射,替换响应中的model字段 return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
} }
return nil, err
} }
// 写入客户端(统一处理 data 行和非 data 行) for _, block := range outputBlocks {
if !clientDisconnected { if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, werr := fmt.Fprint(w, block); werr != nil {
clientDisconnected = true clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else { break
flusher.Flush()
} }
flusher.Flush()
} }
// 无论客户端是否断开,都解析 usage(仅对 data 行)
if data != "" { if data != "" {
if firstTokenMs == nil && data != "[DONE]" { if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
...@@ -3289,6 +4213,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -3289,6 +4213,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} }
}
continue
}
pendingEventLines = append(pendingEventLines, line)
case <-intervalCh: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
...@@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
// replaceModelInSSELine 替换SSE数据行中的model字段 func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { switch v := value.(type) {
if !sseDataRe.MatchString(line) { case map[string]any:
return line changed := false
rewritten := make(map[string]any, len(v))
for key, item := range v {
newKey := normalizeParamNameForOpenCode(key, cache)
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
} }
data := sseDataRe.ReplaceAllString(line, "") if newKey != key {
if data == "" || data == "[DONE]" { changed = true
return line
} }
rewritten[newKey] = newItem
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
} }
if !changed {
// 只替换 message_start 事件中的 message.model return value, false
if event["type"] != "message_start" {
return line
} }
return rewritten, true
case []any:
changed := false
rewritten := make([]any, len(v))
for idx, item := range v {
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
rewritten[idx] = newItem
}
if !changed {
return value, false
}
return rewritten, true
default:
return value, false
}
}
msg, ok := event["message"].(map[string]any) func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
if !ok { switch v := value.(type) {
return line case map[string]any:
changed := false
if blockType, _ := v["type"].(string); blockType == "tool_use" {
if name, ok := v["name"].(string); ok {
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped != name {
v["name"] = mapped
changed = true
}
}
if input, ok := v["input"].(map[string]any); ok {
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
if inputChanged {
if m, ok := rewrittenInput.(map[string]any); ok {
v["input"] = m
changed = true
}
} }
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
case []any:
changed := false
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
default:
return false
}
}
model, ok := msg["model"].(string) func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
if !ok || model != fromModel { if text == "" {
return line return text
}
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
submatches := toolNameFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
name := submatches[1]
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped == name {
return match
}
return strings.Replace(match, name, mapped, 1)
})
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
submatches := modelFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
model := submatches[1]
mapped := claude.DenormalizeModelID(model)
if mapped == model {
return match
} }
return strings.Replace(match, model, mapped, 1)
})
msg["model"] = toModel for mapped, original := range toolNameMap {
newData, err := json.Marshal(event) if mapped == "" || original == "" || mapped == original {
if err != nil { continue
return line }
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
} }
return "data: " + string(newData) return output
} }
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
...@@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} }
} }
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
...@@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel { if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
...@@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo ...@@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody return newBody
} }
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
if len(body) == 0 {
return body
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
replaced := replaceToolNamesInText(string(body), toolNameMap)
if replaced == string(body) {
return body
}
return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
}
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
...@@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值 // Antigravity 账户不支持 count_tokens 转发,直接返回空值
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
...@@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 构建上游请求 // 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err return err
...@@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
...@@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
...@@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
...@@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
// 白名单透传 headers // 白名单透传 headers
for key, values := range c.Request.Header { for key, values := range clientHeaders {
lowerKey := strings.ToLower(key) lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] { if allowedHeaders[lowerKey] {
for _, v := range values { for _, v := range values {
...@@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头 // OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if fp != nil { if fp != nil {
s.identityService.ApplyFingerprint(req, fp) s.identityService.ApplyFingerprint(req, fp)
} }
...@@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" { if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
} }
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) if mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
} else {
beta := s.getBetaHeader(modelID, clientBetaHeader)
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
req.Header.Set("anthropic-beta", beta)
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭) // API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
...@@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil return req, nil
} }
......
...@@ -26,13 +26,13 @@ var ( ...@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用) // 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{ var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)", UserAgent: "claude-cli/2.1.22 (external, cli)",
StainlessLang: "js", StainlessLang: "js",
StainlessPackageVersion: "0.52.0", StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux", StainlessOS: "Linux",
StainlessArch: "x64", StainlessArch: "arm64",
StainlessRuntime: "node", StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0", StainlessRuntimeVersion: "v24.13.0",
} }
// Fingerprint represents account fingerprint data // Fingerprint represents account fingerprint data
...@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string { ...@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
} }
// parseUserAgentVersion 解析user-agent版本号 // parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.0.62 -> (2, 0, 62) // 例如:claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式 // 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua) matches := userAgentVersionRegex.FindStringSubmatch(ua)
......
...@@ -1260,16 +1260,30 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1260,16 +1260,30 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now() lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) // 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
if errorEventSent { if errorEventSent || clientDisconnected {
return return
} }
errorEventSent = true errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) payload := map[string]any{
"type": "error",
"sequence_number": 0,
"error": map[string]any{
"type": "upstream_error",
"message": reason,
"code": reason,
},
}
if b, err := json.Marshal(payload); err == nil {
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
flusher.Flush() flusher.Flush()
} }
}
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
...@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if ev.err != nil { if ev.err != nil {
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large")
...@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData line = "data: " + correctedData
} }
// Forward line // 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed") clientDisconnected = true
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} } else {
flusher.Flush() flusher.Flush()
}
}
// Record first token time // Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
...@@ -1321,18 +1350,25 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1321,18 +1350,25 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else { } else {
// Forward non-data lines as-is // Forward non-data lines as-is
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed") clientDisconnected = true
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} } else {
flusher.Flush() flusher.Flush()
} }
}
}
case <-intervalCh: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
continue continue
} }
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil { if s.rateLimitService != nil {
...@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh: case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
if _, err := fmt.Fprint(w, ":\n\n"); err != nil { if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
continue
} }
flusher.Flush() flusher.Flush()
} }
......
...@@ -59,6 +59,25 @@ type stubConcurrencyCache struct { ...@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad bool skipDefaultLoad bool
} }
type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingGinWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil { if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok { if result, ok := c.acquireResults[accountID]; ok {
...@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) { ...@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err) t.Fatalf("expected stream timeout error, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "stream_timeout") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: cancelReadCloser{},
Header: http.Header{},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if result == nil || result.usage == nil {
t.Fatalf("expected usage result")
}
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
t.Fatalf("unexpected usage: %+v", *result.usage)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
} }
} }
...@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) { ...@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if !errors.Is(err, bufio.ErrTooLong) { if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err) t.Fatalf("expected ErrTooLong, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "response_too_large") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment