"frontend/vscode:/vscode.git/clone" did not exist on "1624523c4ef45e916b3ea81b13d846b53039081e"
Commit a6764e82 authored by shaw's avatar shaw
Browse files

修复 OAuth/SetupToken 转发请求体重排并增加调试开关

parent 9f6ab6b8
...@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string { ...@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
return "" return ""
} }
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
func filterSystemBlockByPrefix(text string) string {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return ""
}
}
return text
}
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) // buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart var parts []GeminiPart
...@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(sysStr, "You are Antigravity") { if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词和黑名单前缀 // 过滤 OpenCode 默认提示词
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr)) filtered := filterOpenCodePrompt(sysStr)
if filtered != "" { if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
} }
...@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(block.Text, "You are Antigravity") { if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词和黑名单前缀 // 过滤 OpenCode 默认提示词
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text)) filtered := filterOpenCodePrompt(block.Text)
if filtered != "" { if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
} }
......
...@@ -2,7 +2,10 @@ package antigravity ...@@ -2,7 +2,10 @@ package antigravity
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 // TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
...@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { ...@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
}) })
} }
} }
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
tests := []struct {
name string
system json.RawMessage
}{
{
name: "system array",
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
},
{
name: "system string",
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-3-5-sonnet-latest",
System: tt.system,
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.NotNil(t, req.Request.SystemInstruction)
found := false
for _, part := range req.Request.SystemInstruction.Parts {
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
found = true
break
}
}
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
})
}
}
...@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t ...@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
} }
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
body string
}{
{
name: "system array",
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
{
name: "system string",
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
require.NoError(t, err)
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-oauth-preserve"},
},
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
},
}
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
}
account := &Account{
ID: 301,
Name: "anthropic-oauth-preserve",
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
})
}
}
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
package service
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
)
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
last := -1
for _, token := range tokens {
pos := strings.Index(body, token)
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
last = pos
}
}
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
svc := &GatewayService{}
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
}
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
injectMetadata: true,
metadataUserID: "user-1",
})
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
result := injectClaudeCodePrompt(body, []any{
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
})
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
}
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
result := enforceCacheControlLimit(body)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}
package service
import "testing"
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
t.Run("default disabled", func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, "")
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled by default")
}
})
t.Run("enabled with true-like values", func(t *testing.T) {
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if !debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
}
})
}
})
t.Run("disabled with other values", func(t *testing.T) {
for _, value := range []string{"0", "false", "off", "debug"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
}
})
}
})
}
...@@ -51,6 +51,7 @@ const ( ...@@ -51,6 +51,7 @@ const (
defaultUserGroupRateCacheTTL = 30 * time.Second defaultUserGroupRateCacheTTL = 30 * time.Second
defaultModelsListCacheTTL = 15 * time.Second defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second postUsageBillingTimeout = 15 * time.Second
debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY"
) )
const ( const (
...@@ -339,12 +340,6 @@ var ( ...@@ -339,12 +340,6 @@ var (
} }
) )
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// ErrNoAvailableAccounts 表示没有可用的账号 // ErrNoAvailableAccounts 表示没有可用的账号
var ErrNoAvailableAccounts = errors.New("no available accounts") var ErrNoAvailableAccounts = errors.New("no available accounts")
...@@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string { ...@@ -840,20 +835,30 @@ func (s *GatewayService) hashContent(content string) string {
return strconv.FormatUint(h, 36) return strconv.FormatUint(h, 36)
} }
type anthropicCacheControlPayload struct {
Type string `json:"type"`
}
type anthropicSystemTextBlockPayload struct {
Type string `json:"type"`
Text string `json:"text"`
CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"`
}
type anthropicMetadataPayload struct {
UserID string `json:"user_id"`
}
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改 // 优先使用定点修改,尽量保持客户端原始字段顺序。
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]json.RawMessage if len(body) == 0 {
if err := json.Unmarshal(body, &req); err != nil {
return body return body
} }
// 只序列化 model 字段 if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
modelBytes, err := json.Marshal(newModel)
if err != nil {
return body return body
} }
req["model"] = modelBytes newBody, err := sjson.SetBytes(body, "model", newModel)
newBody, err := json.Marshal(req)
if err != nil { if err != nil {
return body return body
} }
...@@ -884,121 +889,206 @@ func sanitizeSystemText(text string) string { ...@@ -884,121 +889,206 @@ func sanitizeSystemText(text string) string {
return text return text
} }
func stripCacheControlFromSystemBlocks(system any) bool { func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) {
blocks, ok := system.([]any) block := anthropicSystemTextBlockPayload{
if !ok { Type: "text",
return false Text: text,
} }
changed := false if includeCacheControl {
for _, item := range blocks { block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
block, ok := item.(map[string]any)
if !ok {
continue
} }
if _, exists := block["cache_control"]; !exists { return json.Marshal(block)
continue }
func marshalAnthropicMetadata(userID string) ([]byte, error) {
return json.Marshal(anthropicMetadataPayload{UserID: userID})
}
func buildJSONArrayRaw(items [][]byte) []byte {
if len(items) == 0 {
return []byte("[]")
}
total := 2
for _, item := range items {
total += len(item)
}
total += len(items) - 1
buf := make([]byte, 0, total)
buf = append(buf, '[')
for i, item := range items {
if i > 0 {
buf = append(buf, ',')
} }
delete(block, "cache_control") buf = append(buf, item...)
changed = true
} }
return changed buf = append(buf, ']')
return buf
} }
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) {
if len(body) == 0 { next, err := sjson.SetBytes(body, path, value)
return body, modelID if err != nil {
return body, false
} }
return next, true
}
// 解析为 map[string]any 用于修改字段 func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) {
var req map[string]any next, err := sjson.SetRawBytes(body, path, raw)
if err := json.Unmarshal(body, &req); err != nil { if err != nil {
return body, modelID return body, false
}
return next, true
}
func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) {
next, err := sjson.DeleteBytes(body, path)
if err != nil {
return body, false
}
return next, true
}
func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body, false
} }
out := body
modified := false modified := false
if system, ok := req["system"]; ok { switch {
switch v := system.(type) { case sys.Type == gjson.String:
case string: sanitized := sanitizeSystemText(sys.String())
sanitized := sanitizeSystemText(v) if sanitized != sys.String() {
if sanitized != v { if next, ok := setJSONValueBytes(out, "system", sanitized); ok {
req["system"] = sanitized out = next
modified = true modified = true
} }
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
if blockType, _ := block["type"].(string); blockType != "text" {
continue
}
text, ok := block["text"].(string)
if !ok || text == "" {
continue
} }
case sys.IsArray():
index := 0
sys.ForEach(func(_, item gjson.Result) bool {
if item.Get("type").String() == "text" {
textResult := item.Get("text")
if textResult.Exists() && textResult.Type == gjson.String {
text := textResult.String()
sanitized := sanitizeSystemText(text) sanitized := sanitizeSystemText(text)
if sanitized != text { if sanitized != text {
block["text"] = sanitized if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok {
out = next
modified = true modified = true
} }
} }
} }
} }
if rawModel, ok := req["model"].(string); ok { if opts.stripSystemCacheControl && item.Get("cache_control").Exists() {
normalized := claude.NormalizeModelID(rawModel) if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok {
if normalized != rawModel { out = next
req["model"] = normalized
modelID = normalized
modified = true modified = true
} }
} }
// 确保 tools 字段存在(即使为空数组) index++
if _, exists := req["tools"]; !exists { return true
req["tools"] = []any{} })
}
return out, modified
}
func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) {
if strings.TrimSpace(userID) == "" {
return body, false
}
metadata := gjson.GetBytes(body, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
raw, err := marshalAnthropicMetadata(userID)
if err != nil {
return body, false
}
return setJSONRawBytes(body, "metadata", raw)
}
trimmedRaw := strings.TrimSpace(metadata.Raw)
if strings.HasPrefix(trimmedRaw, "{") {
existing := metadata.Get("user_id")
if existing.Exists() && existing.Type == gjson.String && existing.String() != "" {
return body, false
}
return setJSONValueBytes(body, "metadata.user_id", userID)
}
raw, err := marshalAnthropicMetadata(userID)
if err != nil {
return body, false
}
return setJSONRawBytes(body, "metadata", raw)
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
if len(body) == 0 {
return body, modelID
}
out := body
modified := false
if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed {
out = next
modified = true modified = true
} }
if opts.stripSystemCacheControl { rawModel := gjson.GetBytes(out, "model")
if system, ok := req["system"]; ok { if rawModel.Exists() && rawModel.Type == gjson.String {
_ = stripCacheControlFromSystemBlocks(system) normalized := claude.NormalizeModelID(rawModel.String())
if normalized != rawModel.String() {
if next, ok := setJSONValueBytes(out, "model", normalized); ok {
out = next
modified = true modified = true
} }
modelID = normalized
}
} }
if opts.injectMetadata && opts.metadataUserID != "" { // 确保 tools 字段存在(即使为空数组)
metadata, ok := req["metadata"].(map[string]any) if !gjson.GetBytes(out, "tools").Exists() {
if !ok { if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok {
metadata = map[string]any{} out = next
req["metadata"] = metadata modified = true
} }
if existing, ok := metadata["user_id"].(string); !ok || existing == "" { }
metadata["user_id"] = opts.metadataUserID
if opts.injectMetadata && opts.metadataUserID != "" {
if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed {
out = next
modified = true modified = true
} }
} }
if _, hasTemp := req["temperature"]; hasTemp { if gjson.GetBytes(out, "temperature").Exists() {
delete(req, "temperature") if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
out = next
modified = true modified = true
} }
if _, hasChoice := req["tool_choice"]; hasChoice { }
delete(req, "tool_choice") if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true modified = true
} }
}
if !modified { if !modified {
return body, modelID return body, modelID
} }
newBody, err := json.Marshal(req) return out, modelID
if err != nil {
return body, modelID
}
return newBody, modelID
} }
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
...@@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool { ...@@ -3676,82 +3766,28 @@ func hasClaudeCodePrefix(text string) bool {
return false return false
} }
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
func matchesFilterPrefix(text string) bool {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system)
func filterSystemBlocksByPrefix(body []byte) []byte {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body
}
switch {
case sys.Type == gjson.String:
if matchesFilterPrefix(sys.Str) {
result, err := sjson.DeleteBytes(body, "system")
if err != nil {
return body
}
return result
}
case sys.IsArray():
var parsed []any
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
return body
}
filtered := make([]any, 0, len(parsed))
changed := false
for _, item := range parsed {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
changed = true
continue
}
}
filtered = append(filtered, item)
}
if changed {
result, err := sjson.SetBytes(body, "system", filtered)
if err != nil {
return body
}
return result
}
}
return body
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式 // 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte { func injectClaudeCodePrompt(body []byte, system any) []byte {
claudeCodeBlock := map[string]any{ claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
"type": "text", if err != nil {
"text": claudeCodeSystemPrompt, logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
"cache_control": map[string]string{"type": "ephemeral"}, return body
} }
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus // banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions. // a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
var newSystem []any var items [][]byte
switch v := system.(type) { switch v := system.(type) {
case nil: case nil:
newSystem = []any{claudeCodeBlock} items = [][]byte{claudeCodeBlock}
case string: case string:
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
newSystem = []any{claudeCodeBlock} items = [][]byte{claudeCodeBlock}
} else { } else {
// Mirror opencode behavior: keep the banner as a separate system entry, // Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner. // but also prefix the next system text with the banner.
...@@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { ...@@ -3759,18 +3795,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
if !strings.HasPrefix(v, claudeCodePrefix) { if !strings.HasPrefix(v, claudeCodePrefix) {
merged = claudeCodePrefix + "\n\n" + v merged = claudeCodePrefix + "\n\n" + v
} }
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false)
if buildErr != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr)
return body
}
items = [][]byte{claudeCodeBlock, nextBlock}
} }
case []any: case []any:
newSystem = make([]any, 0, len(v)+1) items = make([][]byte, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock) items = append(items, claudeCodeBlock)
prefixedNext := false prefixedNext := false
systemResult := gjson.GetBytes(body, "system")
if systemResult.IsArray() {
systemResult.ForEach(func(_, item gjson.Result) bool {
textResult := item.Get("text")
if textResult.Exists() && textResult.Type == gjson.String &&
strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) {
return true
}
raw := []byte(item.Raw)
// Prefix the first subsequent text system block once.
if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String {
text := textResult.String()
if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text)
if setErr == nil {
raw = next
prefixedNext = true
}
}
}
items = append(items, raw)
return true
})
} else {
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { m, ok := item.(map[string]any)
if !ok {
raw, marshalErr := json.Marshal(item)
if marshalErr == nil {
items = append(items, raw)
}
continue
}
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
continue continue
} }
// Prefix the first subsequent text system block once.
if !prefixedNext { if !prefixedNext {
if blockType, _ := m["type"].(string); blockType == "text" { if blockType, _ := m["type"].(string); blockType == "text" {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
...@@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { ...@@ -3779,197 +3851,150 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
} }
} }
} }
raw, marshalErr := json.Marshal(m)
if marshalErr == nil {
items = append(items, raw)
}
} }
newSystem = append(newSystem, item)
} }
default: default:
newSystem = []any{claudeCodeBlock} items = [][]byte{claudeCodeBlock}
} }
result, err := sjson.SetBytes(body, "system", newSystem) result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items))
if err != nil { if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err) logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt")
return body return body
} }
return result return result
} }
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) type cacheControlPath struct {
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制 path string
func enforceCacheControlLimit(body []byte) []byte { log string
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
return body
}
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
removeCacheControlFromThinkingBlocks(data)
// 计算当前 cache_control 块数量
count := countCacheControlBlocks(data)
if count <= maxCacheControlBlocks {
return body
}
// 超限:优先从 messages 中移除,再从 system 中移除
for count > maxCacheControlBlocks {
if removeCacheControlFromMessages(data) {
count--
continue
}
if removeCacheControlFromSystem(data) {
count--
continue
}
break
}
result, err := json.Marshal(data)
if err != nil {
return body
}
return result
} }
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量 func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
// 注意:thinking 块不支持 cache_control,统计时跳过 system := gjson.GetBytes(body, "system")
func countCacheControlBlocks(data map[string]any) int { if system.IsArray() {
count := 0 sysIndex := 0
system.ForEach(func(_, item gjson.Result) bool {
// 统计 system 中的块 if item.Get("cache_control").Exists() {
if system, ok := data["system"].([]any); ok { path := fmt.Sprintf("system.%d.cache_control", sysIndex)
for _, item := range system { if item.Get("type").String() == "thinking" {
if m, ok := item.(map[string]any); ok { invalidThinking = append(invalidThinking, cacheControlPath{
// thinking 块不支持 cache_control,跳过 path: path,
if blockType, _ := m["type"].(string); blockType == "thinking" { log: "[Warning] Removed illegal cache_control from thinking block in system",
continue })
} } else {
if _, has := m["cache_control"]; has { systemPaths = append(systemPaths, path)
count++
}
} }
} }
sysIndex++
return true
})
} }
// 统计 messages 中的块 messages := gjson.GetBytes(body, "messages")
if messages, ok := data["messages"].([]any); ok { if messages.IsArray() {
for _, msg := range messages { msgIndex := 0
if msgMap, ok := msg.(map[string]any); ok { messages.ForEach(func(_, msg gjson.Result) bool {
if content, ok := msgMap["content"].([]any); ok { content := msg.Get("content")
for _, item := range content { if content.IsArray() {
if m, ok := item.(map[string]any); ok { contentIndex := 0
// thinking 块不支持 cache_control,跳过 content.ForEach(func(_, item gjson.Result) bool {
if blockType, _ := m["type"].(string); blockType == "thinking" { if item.Get("cache_control").Exists() {
continue path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
} if item.Get("type").String() == "thinking" {
if _, has := m["cache_control"]; has { invalidThinking = append(invalidThinking, cacheControlPath{
count++ path: path,
} log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
} })
} } else {
messagePaths = append(messagePaths, path)
} }
} }
contentIndex++
return true
})
} }
msgIndex++
return true
})
} }
return count return invalidThinking, messagePaths, systemPaths
} }
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始) // enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
// 返回 true 表示成功移除,false 表示没有可移除的 // 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
// 注意:跳过 thinking 块(它不支持 cache_control) func enforceCacheControlLimit(body []byte) []byte {
func removeCacheControlFromMessages(data map[string]any) bool { if len(body) == 0 {
messages, ok := data["messages"].([]any) return body
if !ok {
return false
} }
for _, msg := range messages { invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
msgMap, ok := msg.(map[string]any) out := body
if !ok { modified := false
// 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
for _, item := range invalidThinking {
if !gjson.GetBytes(out, item.path).Exists() {
continue continue
} }
content, ok := msgMap["content"].([]any) next, ok := deleteJSONPathBytes(out, item.path)
if !ok { if !ok {
continue continue
} }
for _, item := range content { out = next
if m, ok := item.(map[string]any); ok { modified = true
// thinking 块不支持 cache_control,跳过 logger.LegacyPrintf("service.gateway", "%s", item.log)
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
} }
count := len(messagePaths) + len(systemPaths)
if count <= maxCacheControlBlocks {
if modified {
return out
} }
return body
} }
return false
}
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt) // 超限:优先从 messages 中移除,再从 system 中移除
// 返回 true 表示成功移除,false 表示没有可移除的 remaining := count - maxCacheControlBlocks
// 注意:跳过 thinking 块(它不支持 cache_control) for _, path := range messagePaths {
func removeCacheControlFromSystem(data map[string]any) bool { if remaining <= 0 {
system, ok := data["system"].([]any) break
if !ok {
return false
} }
if !gjson.GetBytes(out, path).Exists() {
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for i := len(system) - 1; i >= 0; i-- {
if m, ok := system[i].(map[string]any); ok {
// thinking 块不支持 cache_control,跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue continue
} }
if _, has := m["cache_control"]; has { next, ok := deleteJSONPathBytes(out, path)
delete(m, "cache_control") if !ok {
return true continue
}
} }
out = next
modified = true
remaining--
} }
return false
}
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段 path := systemPaths[i]
func removeCacheControlFromThinkingBlocks(data map[string]any) { if !gjson.GetBytes(out, path).Exists() {
// 清理 system 中的 thinking 块 continue
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system")
}
}
} }
next, ok := deleteJSONPathBytes(out, path)
if !ok {
continue
} }
out = next
modified = true
remaining--
} }
// 清理 messages 中的 thinking 块 if modified {
if messages, ok := data["messages"].([]any); ok { return out
for msgIdx, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for contentIdx, item := range content {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
}
}
}
}
}
}
}
} }
return body
} }
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
...@@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4021,6 +4046,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqStream := parsed.Stream reqStream := parsed.Stream
originalModel := reqModel originalModel := reqModel
// === DEBUG: 打印客户端原始请求 body ===
debugLogRequestBody("CLIENT_ORIGINAL", body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
...@@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4046,12 +4074,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
} }
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
// 放在 inject/normalize 之后,确保不会被覆盖
if account.IsOAuth() {
body = filterSystemBlocksByPrefix(body)
}
// 强制执行 cache_control 块数量限制(最多 4 个) // 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body) body = enforceCacheControlLimit(body)
...@@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5573,6 +5595,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// === DEBUG: 打印转发给上游的 body(metadata 已重写) ===
debugLogRequestBody("UPSTREAM_FORWARD", body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool { ...@@ -8447,3 +8472,43 @@ func reconcileCachedTokens(usage map[string]any) bool {
usage["cache_read_input_tokens"] = cached usage["cache_read_input_tokens"] = cached
return true return true
} }
func debugGatewayBodyLoggingEnabled() bool {
raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv))
if raw == "" {
return false
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。
// 默认关闭,仅在设置环境变量时启用:
//
// SUB2API_DEBUG_GATEWAY_BODY=1
func debugLogRequestBody(tag string, body []byte) {
if !debugGatewayBodyLoggingEnabled() {
return
}
if len(body) == 0 {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag)
return
}
// 提取 metadata 字段完整打印
metadataResult := gjson.GetBytes(body, "metadata")
if metadataResult.Exists() {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw)
} else {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag)
}
// 全量打印 body
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body))
}
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
...@@ -15,6 +14,8 @@ import ( ...@@ -15,6 +14,8 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
...@@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -215,25 +216,20 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil return body, nil
} }
// 使用 RawMessage 保留其他字段的原始字节 metadata := gjson.GetBytes(body, "metadata")
var reqMap map[string]json.RawMessage if !metadata.Exists() || metadata.Type == gjson.Null {
if err := json.Unmarshal(body, &reqMap); err != nil {
return body, nil return body, nil
} }
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
return body, nil return body, nil
} }
var metadata map[string]any userIDResult := metadata.Get("user_id")
if err := json.Unmarshal(metadataRaw, &metadata); err != nil { if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return body, nil return body, nil
} }
userID := userIDResult.String()
userID, ok := metadata["user_id"].(string) if userID == "" {
if !ok || userID == "" {
return body, nil return body, nil
} }
...@@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -252,17 +248,15 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// 根据客户端版本选择输出格式 // 根据客户端版本选择输出格式
version := ExtractCLIVersion(fingerprintUA) version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version) newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
if newUserID == userID {
return body, nil
}
metadata["user_id"] = newUserID newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
// 只重新序列化 metadata 字段
newMetadataRaw, err := json.Marshal(metadata)
if err != nil { if err != nil {
return body, nil return body, nil
} }
reqMap["metadata"] = newMetadataRaw return newBody, nil
return json.Marshal(reqMap)
} }
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
...@@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -283,25 +277,20 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil return newBody, nil
} }
// 使用 RawMessage 保留其他字段的原始字节 metadata := gjson.GetBytes(newBody, "metadata")
var reqMap map[string]json.RawMessage if !metadata.Exists() || metadata.Type == gjson.Null {
if err := json.Unmarshal(newBody, &reqMap); err != nil {
return newBody, nil return newBody, nil
} }
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
return newBody, nil return newBody, nil
} }
var metadata map[string]any userIDResult := metadata.Get("user_id")
if err := json.Unmarshal(metadataRaw, &metadata); err != nil { if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return newBody, nil return newBody, nil
} }
userID := userIDResult.String()
userID, ok := metadata["user_id"].(string) if userID == "" {
if !ok || userID == "" {
return newBody, nil return newBody, nil
} }
...@@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -339,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
"after", newUserID, "after", newUserID,
) )
metadata["user_id"] = newUserID if newUserID == userID {
// 只重新序列化 metadata 字段
newMetadataRaw, marshalErr := json.Marshal(metadata)
if marshalErr != nil {
return newBody, nil return newBody, nil
} }
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap) maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
if setErr != nil {
return newBody, nil
}
return maskedBody, nil
} }
// generateRandomUUID 生成随机 UUID v4 格式字符串 // generateRandomUUID 生成随机 UUID v4 格式字符串
......
package service
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type identityCacheStub struct {
maskedSessionID string
}
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
return nil, nil
}
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
return nil
}
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
return s.maskedSessionID, nil
}
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
s.maskedSessionID = sessionID
return nil
}
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.NotContains(t, resultStr, originalUserID)
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
}
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
account := &Account{
ID: 123,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"session_id_masking_enabled": true,
},
}
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.Contains(t, resultStr, cache.maskedSessionID)
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
}
func strconvQuote(v string) string {
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment