Commit 3b7a5fff authored by 陈曦's avatar 陈曦
Browse files

补充 openai、gemini 以及流式请求的采集数据,并落库到 nfs

parent 8519a8eb
Pipeline #82284 failed with stage
in 2 minutes and 21 seconds
......@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}
......@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
......@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
......@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
......@@ -621,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
......@@ -657,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
......
......@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}
......
......@@ -18,6 +18,19 @@ const (
RoleUser = domain.RoleUser
)
// Affiliate rebate settings: defaults and bounds for the invite-rebate feature.
const (
AffiliateRebateRateDefault = 20.0
AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0
AffiliateEnabledDefault = false // master switch for invite rebates; off by default
AffiliateRebateFreezeHoursDefault = 0 // 0 = no freeze period (backward compatible)
AffiliateRebateFreezeHoursMax = 720 // at most 30 days
AffiliateRebateDurationDaysDefault = 0 // 0 = valid forever
AffiliateRebateDurationDaysMax = 3650 // ~10 years
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = no cap
)
// Platform constants
const (
PlatformAnthropic = domain.PlatformAnthropic
......@@ -87,6 +100,11 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
......@@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.True(t, system.IsArray(), "system should be an array")
require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
arr := system.Array()
require.Len(t, arr, 2, "system array should have billing block + cc prompt block")
require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:")
require.Contains(t, arr[0].Get("text").String(), "cc_version=")
require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String())
require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String())
// 原始 system prompt 应迁移至 messages 中
messages := gjson.GetBytes(upstream.lastBody, "messages")
......
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
)
// fingerprintSalt is the salt mixed into the cc_version suffix fingerprint.
//
// Per the original author's note it mirrors FINGERPRINT_SALT in Parrot's
// src/transform/cc_mimicry.py and was derived from captured Claude Code CLI
// traffic; changing it would make the fingerprint diverge from the CLI's and
// reportedly trips Anthropic's third-party detection.
const fingerprintSalt = "59cf53e54c78"
// computeClaudeCodeFingerprint reproduces the cc_version fingerprint suffix:
//
//  1. take the plain text of the first role=user message (its first text block);
//  2. pick the characters at positions 4, 7 and 20 of that text, substituting
//     '0' for any position past the end of the text;
//  3. return the first 3 hex characters of SHA256(salt + chars + version).
//
// Positions are counted in Unicode code points, not bytes. Byte indexing would
// pick fragments of multi-byte UTF-8 sequences for non-ASCII prompts and hash
// a different string than a character-indexed implementation does. For pure
// ASCII text the result is unchanged from byte indexing.
//
// NOTE(review): the author's notes say the algorithm is ported from Parrot's
// src/transform/cc_mimicry.py:compute_fingerprint (Python str indexing, i.e.
// code points). Confirm against that reference for characters outside the BMP.
func computeClaudeCodeFingerprint(body []byte, version string) string {
	firstText := extractFirstUserText(body)

	const lastPos = 20
	// Every slot starts as the '0' padding and is overwritten only when the
	// text is long enough to reach that position.
	picked := [3]rune{'0', '0', '0'}
	pos := 0
	for _, r := range firstText {
		switch pos {
		case 4:
			picked[0] = r
		case 7:
			picked[1] = r
		case lastPos:
			picked[2] = r
		}
		if pos == lastPos {
			break // nothing past the last sampled position matters
		}
		pos++
	}

	sum := sha256.Sum256([]byte(fingerprintSalt + string(picked[:]) + version))
	return hex.EncodeToString(sum[:])[:3]
}
// extractFirstUserText returns the leading text of the first role=user entry
// in body's "messages" array.
//
// Both content shapes are handled: a plain JSON string is returned as-is; for
// an array of blocks the "text" of the first block whose type is "text" is
// returned. Only the first user message is ever inspected: if it carries no
// usable text (or "messages" is not an array) the result is "".
func extractFirstUserText(body []byte) string {
	msgs := gjson.GetBytes(body, "messages")
	if !msgs.IsArray() {
		return ""
	}
	for _, msg := range msgs.Array() {
		if msg.Get("role").String() != "user" {
			continue
		}
		content := msg.Get("content")
		switch {
		case content.Type == gjson.String:
			return content.String()
		case content.IsArray():
			for _, block := range content.Array() {
				if block.Get("type").String() == "text" {
					return block.Get("text").String()
				}
			}
		}
		// The first user message decides the outcome even when it has no text.
		return ""
	}
	return ""
}
// buildBillingAttributionBlockJSON builds the billing attribution block that
// goes into the system array, serialized exactly as:
//
//	{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"}
//
// A struct (not a map) is marshaled so that the keys are emitted in the
// documented order "type", "text"; encoding/json sorts map keys, which would
// put "text" first.
//
// Per the author's notes, cch=00000 is a signature placeholder that
// signBillingHeaderCCH later replaces while the upstream request is built, and
// the block intentionally carries no cache_control (the following Claude Code
// prompt block holds the cache breakpoint).
//
// It returns an error when cliVersion is empty or marshaling fails.
func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) {
	if cliVersion == "" {
		return nil, fmt.Errorf("cliVersion required")
	}
	fp := computeClaudeCodeFingerprint(body, cliVersion)
	block := struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}{
		Type: "text",
		Text: fmt.Sprintf(
			"x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;",
			cliVersion, fp,
		),
	}
	return json.Marshal(block)
}
......@@ -41,12 +41,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`)
require.Contains(t, resultStr, `"temperature":0.2`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
require.Contains(t, resultStr, `"max_tokens":128000`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
......
......@@ -85,15 +85,16 @@ func (s *GatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts
isClaudeCode := false // CC API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts.
// Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
......@@ -312,7 +313,16 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
var responseBody string
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, ccResp)
}
return &ForwardResult{
RequestID: requestID,
......@@ -322,6 +332,7 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -357,6 +368,7 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant 文本用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -375,15 +387,23 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
// 收集 assistant text 用于响应捕获
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != nil {
textBuilder.WriteString(*chunk.Choices[0].Delta.Content)
}
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return false
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
// Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
// c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
return true // client disconnected
}
return false
......
......@@ -82,15 +82,16 @@ func (s *GatewayService) ForwardAsResponses(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
isClaudeCode := false // Responses API is never Claude Code
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
// OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
// 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
// 否则会被 Anthropic 判为第三方应用并扣 extra usage。
// 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
......@@ -331,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
if respBytes, err := json.Marshal(responsesResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, responsesResp)
}
return &ForwardResult{
RequestID: requestID,
......@@ -419,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
if _, err := fmt.Fprint(c.Writer, out); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
......@@ -439,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
out := string(reverseToolNamesIfPresent(c, []byte(sse)))
fmt.Fprint(c.Writer, out) //nolint:errcheck
}
c.Writer.Flush()
}
......
package service
import (
	"bytes"
	"encoding/json"
	"fmt"

	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)
// stripMessageCacheControl deletes messages[*].content[*].cache_control from
// body and returns the result. Messages whose content is not an array are left
// untouched, and a failed delete simply keeps that block as it was.
//
// Rationale (from the author's notes, mirroring Parrot's
// _strip_message_cache_control): clients tend to pin cache_control on what is
// currently the last user message; once more turns are appended that marker
// sits mid-conversation and changes the cached prefix, so all client markers
// are dropped and the proxy re-adds stable ones via addMessageCacheBreakpoints.
func stripMessageCacheControl(body []byte) []byte {
	msgs := gjson.GetBytes(body, "messages")
	if !msgs.IsArray() {
		return body
	}
	for mi, msg := range msgs.Array() {
		content := msg.Get("content")
		if !content.IsArray() {
			continue
		}
		for bi, block := range content.Array() {
			if !block.Get("cache_control").Exists() {
				continue
			}
			// Removing a field never shifts message/block indices, so paths
			// computed from the initial snapshot stay valid as body changes.
			target := fmt.Sprintf("messages.%d.content.%d.cache_control", mi, bi)
			stripped, err := sjson.DeleteBytes(body, target)
			if err != nil {
				continue
			}
			body = stripped
		}
	}
	return body
}
// addMessageCacheBreakpoints places up to two cache breakpoints on messages:
//  1. on the last message;
//  2. when there are at least 4 messages, on the second role=user message
//     counted from the end.
//
// The actual cache_control value and the "keep an existing ttl" rule live in
// injectCacheControlOnLastContentBlock. According to the author's notes this
// matches Parrot's add_cache_breakpoints, and together with the system prompt
// and tools[-1] breakpoints stays within Anthropic's limit of 4.
//
// Callers are expected to run stripMessageCacheControl first so the result is
// stable across turns.
func addMessageCacheBreakpoints(body []byte) []byte {
	msgs := gjson.GetBytes(body, "messages")
	if !msgs.IsArray() {
		return body
	}
	items := msgs.Array()
	total := len(items)
	if total == 0 {
		return body
	}

	tail := total - 1
	body = injectCacheControlOnLastContentBlock(body, tail, &items[tail])
	if total < 4 {
		return body
	}

	usersSeen := 0
	for i := tail; i >= 0; i-- {
		if items[i].Get("role").String() == "user" {
			usersSeen++
			if usersSeen == 2 {
				return injectCacheControlOnLastContentBlock(body, i, &items[i])
			}
		}
	}
	return body
}
// injectCacheControlOnLastContentBlock marks the last content block of
// messages[idx] as a cache breakpoint and returns the updated body.
//
//   - string content is first upgraded to a one-element text-block array that
//     already carries the breakpoint;
//   - an existing cache_control with a non-empty ttl is left alone;
//   - an existing cache_control without ttl only gets the default ttl added;
//   - otherwise {"type":"ephemeral","ttl":<default>} is written.
//
// Non-string/non-array content, an empty content array, or a failed write all
// return body unchanged. msg is the caller's snapshot of messages[idx]; it is
// only read, which saves re-parsing body.
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
	content := msg.Get("content")
	contentPath := fmt.Sprintf("messages.%d.content", idx)

	if content.Type == gjson.String {
		upgraded := fmt.Sprintf(
			`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
			mustJSONString(content.String()), claude.DefaultCacheControlTTL,
		)
		updated, err := sjson.SetRawBytes(body, contentPath, []byte(upgraded))
		if err != nil {
			return body
		}
		return updated
	}

	if !content.IsArray() {
		return body
	}
	blocks := content.Array()
	if len(blocks) == 0 {
		return body
	}

	lastIdx := len(blocks) - 1
	ccPath := fmt.Sprintf("%s.%d.cache_control", contentPath, lastIdx)
	existing := blocks[lastIdx].Get("cache_control")

	var (
		updated []byte
		err     error
	)
	switch {
	case !existing.Exists():
		fresh := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
		updated, err = sjson.SetRawBytes(body, ccPath, []byte(fresh))
	case existing.Get("ttl").String() != "":
		// The block already specifies a ttl; do not override it.
		return body
	default:
		updated, err = sjson.SetBytes(body, ccPath+".ttl", claude.DefaultCacheControlTTL)
	}
	if err != nil {
		return body
	}
	return updated
}
// mustJSONString serializes s as a valid JSON string literal (quotes
// included) for hand-assembled JSON passed to sjson.SetRawBytes.
//
// It uses encoding/json rather than fmt's %q: %q emits Go string syntax,
// whose escapes for control characters, DEL and invalid UTF-8 (\x01, \a, \v,
// \x7f, \xff, ...) are not legal JSON and would corrupt the request body.
// HTML escaping is disabled so '<', '>' and '&' pass through unchanged.
func mustJSONString(s string) string {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(s); err != nil {
		// Encoding a plain string is not expected to fail; keep the previous
		// best-effort quoting rather than panicking on the request path.
		return fmt.Sprintf("%q", s)
	}
	// Encoder.Encode appends a trailing newline after the value.
	return string(bytes.TrimSuffix(buf.Bytes(), []byte("\n")))
}
......@@ -9,6 +9,11 @@ import (
)
func TestIsClaudeCodeClient(t *testing.T) {
// 合法的 legacy 格式 metadata.user_id(64位 hex + account uuid + session uuid)
legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
// 合法的 JSON 格式 metadata.user_id(2.1.78+ 版本)
jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}`
tests := []struct {
name string
userAgent string
......@@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) {
want bool
}{
{
name: "Claude Code client",
name: "Claude Code client with legacy user_id",
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
metadataUserID: legacyUserID,
want: true,
},
{
name: "Claude Code without version suffix",
userAgent: "claude-cli/2.0.0",
metadataUserID: "session_abc",
name: "Claude Code client with JSON user_id",
userAgent: "claude-cli/2.1.92 (external, cli)",
metadataUserID: jsonUserID,
want: true,
},
{
name: "Claude Code case insensitive UA",
userAgent: "Claude-CLI/2.0.0",
metadataUserID: legacyUserID,
want: true,
},
{
......@@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) {
want: false,
},
{
name: "Different user agent",
name: "Claude CLI UA with invalid user_id format",
userAgent: "claude-cli/2.0.0",
metadataUserID: "fake-user-id-12345",
want: false,
},
{
name: "Different user agent with valid user_id",
userAgent: "curl/7.68.0",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Empty user agent",
userAgent: "",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Similar but not Claude CLI",
userAgent: "claude-api/1.0.0",
metadataUserID: "user123",
metadataUserID: legacyUserID,
want: false,
},
{
name: "Opencode spoofing UA with arbitrary user_id",
userAgent: "claude-cli/2.1.92",
metadataUserID: "session_abc",
want: false,
},
}
......@@ -378,16 +401,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
// system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
// system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution block (x-anthropic-billing-header: cc_version=...;)
// [1] Claude Code prompt block (带 cache_control)
systemArr, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array, got %T", parsed["system"])
require.Len(t, systemArr, 1, "system array should have exactly 1 block")
systemBlock, ok := systemArr[0].(map[string]any)
require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)")
billingBlock, ok := systemArr[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", billingBlock["type"])
require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:")
require.Contains(t, billingBlock["text"], "cc_version=")
require.Contains(t, billingBlock["text"], "cc_entrypoint=cli")
require.Contains(t, billingBlock["text"], "cch=00000")
systemBlock, ok := systemArr[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", systemBlock["type"])
require.Equal(t, tt.wantSystemText, systemBlock["text"])
cc, ok := systemBlock["cache_control"].(map[string]any)
require.True(t, ok, "system block should have cache_control")
require.True(t, ok, "cc prompt block should have cache_control")
require.Equal(t, "ephemeral", cc["type"])
// 检查 messages
......
......@@ -119,7 +119,7 @@ func openAIStreamEventIsTerminal(data string) bool {
return true
}
switch gjson.Get(trimmed, "type").String() {
case "response.completed", "response.done", "response.failed":
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
......@@ -329,7 +329,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
......@@ -854,6 +854,7 @@ func (s *GatewayService) hashContent(content string) string {
type anthropicCacheControlPayload struct {
Type string `json:"type"`
TTL string `json:"ttl,omitempty"`
}
type anthropicSystemTextBlockPayload struct {
......@@ -902,7 +903,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b
Text: text,
}
if includeCacheControl {
block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
block.CacheControl = &anthropicCacheControlPayload{
Type: "ephemeral",
TTL: claude.DefaultCacheControlTTL,
}
}
return json.Marshal(block)
}
......@@ -1078,18 +1082,51 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if gjson.GetBytes(out, "temperature").Exists() {
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
// temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。
// 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。
// 策略:客户端传了什么就透传;没传则补默认 1。
if !gjson.GetBytes(out, "temperature").Exists() {
if next, ok := setJSONValueBytes(out, "temperature", 1); ok {
out = next
modified = true
}
}
// max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
if !gjson.GetBytes(out, "max_tokens").Exists() {
if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok {
out = next
modified = true
}
}
// context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
// 客户端显式传了就透传;否则按 CLI 行为补齐。
if !gjson.GetBytes(out, "context_management").Exists() {
thinkingType := gjson.GetBytes(out, "thinking.type").String()
if thinkingType == "enabled" || thinkingType == "adaptive" {
const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}`
if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok {
out = next
modified = true
}
}
}
// tool_choice:与 Parrot 对齐,不再无条件删除。
// - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由
// applyToolNameRewriteToBody 同步映射为假名
// - 其他形态(auto/any/none)原样透传
// 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除
if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
if gjson.GetBytes(out, "tool_choice").Exists() {
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
out = next
modified = true
}
}
}
if !modified {
return body, modelID
......@@ -1132,6 +1169,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// applyClaudeCodeOAuthMimicryToBody runs the full "non-Claude-Code client on a
// Claude OAuth account" rewrite pipeline over an arbitrary Anthropic
// /v1/messages body, so that the OpenAI-compatible entry points
// (ForwardAsChatCompletions / ForwardAsResponses) can share it.
//
// Steps, in order:
//  1. unless the model name contains "haiku", rewrite the system field via
//     rewriteSystemForNonClaudeCode;
//  2. optionally prepare a metadata.user_id (needs identityService, a request
//     on c, a fingerprint, and the second value returned by
//     GetGatewayForwardingSettings being false);
//  3. normalizeClaudeOAuthRequestBody;
//  4. stripMessageCacheControl, then addMessageCacheBreakpoints;
//  5. rewrite tool names and stash the mapping on c under toolNameRewriteKey,
//     or, when there is nothing to rewrite, applyToolsLastCacheBreakpoint.
//
// Parameters:
//   - ctx, c: used for fingerprint and gateway settings lookups; c may be nil.
//   - account: expected to be an OAuth account; otherwise body is returned as-is.
//   - body: request already marshaled in Anthropic /v1/messages format.
//   - systemRaw: the original system value of body.
//   - model: model ID sent upstream (drives the haiku bypass).
//
// The rewritten body is returned; a nil/non-OAuth account or an empty body
// returns the input unchanged.
func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	systemRaw any,
	model string,
) []byte {
	if account == nil || !account.IsOAuth() || len(body) == 0 {
		return body
	}

	// Haiku requests keep their system prompt; in that case normalization is
	// asked to strip system cache_control instead.
	isHaiku := strings.Contains(strings.ToLower(model), "haiku")
	if !isHaiku {
		body = rewriteSystemForNonClaudeCode(body, systemRaw)
	}
	opts := claudeOAuthNormalizeOptions{stripSystemCacheControl: isHaiku}

	if s.identityService != nil && c != nil && c.Request != nil {
		fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
		if err == nil && fp != nil {
			skipMetadata := false
			if s.settingService != nil {
				_, skipMetadata, _ = s.settingService.GetGatewayForwardingSettings(ctx)
			}
			if !skipMetadata {
				if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
					opts.injectMetadata = true
					opts.metadataUserID = uid
				}
			}
		}
	}

	body, _ = normalizeClaudeOAuthRequestBody(body, model, opts)

	// Order matters: drop client markers, add our own message breakpoints,
	// and only then touch tools (name rewrite + tools[-1] breakpoint).
	body = stripMessageCacheControl(body)
	body = addMessageCacheBreakpoints(body)

	rewrite := buildToolNameRewriteFromBody(body)
	if rewrite == nil {
		return applyToolsLastCacheBreakpoint(body)
	}
	body = applyToolNameRewriteToBody(body, rewrite)
	if c != nil {
		// Kept on the gin context so the response side can reverse the mapping.
		c.Set(toolNameRewriteKey, rewrite)
	}
	return body
}
// buildOAuthMetadataUserIDFromBody builds a metadata.user_id value for callers
// that only have the raw request body (no ParsedRequest), such as the
// OpenAI-compatible entry points.
//
// It returns "" when account is nil or when body already has a non-empty
// metadata.user_id. Otherwise:
//   - the user ID is the account's Claude user ID, falling back to the
//     fingerprint's ClientID and finally to a freshly generated client ID;
//   - the session ID is derived from the account ID plus a hash of body, with
//     a random UUID used when no hash is available;
//   - the CLI version is extracted from the fingerprint's User-Agent, if any.
func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
	ctx context.Context,
	account *Account,
	fp *Fingerprint,
	body []byte,
) string {
	_ = ctx
	if account == nil {
		return ""
	}
	if gjson.GetBytes(body, "metadata.user_id").String() != "" {
		return ""
	}

	clientID := strings.TrimSpace(account.GetClaudeUserID())
	if clientID == "" && fp != nil {
		clientID = fp.ClientID
	}
	if clientID == "" {
		clientID = generateClientID()
	}

	sessionID := uuid.NewString()
	if seed := hashBodyForSessionSeed(body); seed != "" {
		sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, seed))
	}

	var cliVersion string
	if fp != nil {
		cliVersion = ExtractCLIVersion(fp.UserAgent)
	}

	accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
	return FormatMetadataUserID(clientID, accountUUID, sessionID, cliVersion)
}
// hashBodyForSessionSeed derives a deterministic session seed from body: the
// lowercase hex encoding of the first 16 bytes of its SHA-256 digest
// (32 characters). An empty body yields "".
func hashBodyForSessionSeed(body []byte) string {
	if len(body) > 0 {
		digest := sha256.Sum256(body)
		return fmt.Sprintf("%x", digest[:16])
	}
	return ""
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
return generateSessionUUID(seed)
......@@ -3547,23 +3713,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。
// 判定条件:
// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感)
// 2. metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式)
//
// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA
// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID
// 验证格式才能确认是真正的 Claude Code 客户端。
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
if metadataUserID == "" {
return false
}
return claudeCliUserAgentRe.MatchString(userAgent)
}
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
if !claudeCliUserAgentRe.MatchString(userAgent) {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
return ParseMetadataUserID(metadataUserID) != nil
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
......@@ -3742,17 +3904,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
// 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态:
// [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;)
// [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点)
//
// billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的
// signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload
// 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。
billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion)
ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if billingErr != nil || ccErr != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr)
return body
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock}))
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
......@@ -3989,15 +4154,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
})
}
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
// Claude Code 客户端判定:UA 匹配 claude-cli/* 且携带 metadata.user_id。
// 真正的 Claude Code 客户端自带完整的 system prompt、cache_control 断点和 header,
// 不需要代理做任何 body 级别的 mimicry;强行替换反而会破坏客户端的缓存策略
// (长 system prompt 被替换为 ~45 tokens 的短 prompt,低于 Anthropic 1024 token
// 最低缓存门槛,导致系统级缓存失效)。
//
// 对于非 Claude Code 的第三方客户端(opencode 等),仍然走完整 mimicry。
isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
// 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code
// 风格的 system prompt)。原因:第三方工具(opencode 等)会发 "You are Claude
// Code..." system prompt 但缺少 billing attribution block,导致 Anthropic
// 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。
// Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
if !strings.Contains(strings.ToLower(reqModel), "haiku") {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
......@@ -4021,6 +4195,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
// D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
// 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
c.Set(toolNameRewriteKey, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
......@@ -4494,6 +4680,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
// 若注入了 ResponseCaptureBuffer,从 context 中读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
nonStreamingResponseBody = captureBuilder.String()
}
} else {
var nonStreamRespBody []byte
nonStreamRespBody, usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
......@@ -4963,7 +5153,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
}
if !clientDisconnected {
if _, err := io.WriteString(w, line); err != nil {
restored := string(reverseToolNamesIfPresent(c, []byte(line)))
if _, err := io.WriteString(w, restored); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
......@@ -5133,6 +5324,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
if contentType == "" {
contentType = "application/json"
}
body = reverseToolNamesIfPresent(c, body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
......@@ -5588,7 +5780,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
setHeaderRaw(req.Header, "x-api-key", token)
}
// 白名单透传headers(恢复真实 wire casing)
// 白名单透传 headers
// OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。
// Parrot 的 build_upstream_headers 只发 9 个精确 header,不透传任何客户端 header。
// 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent /
// x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。
if tokenType != "oauth" || !mimicClaudeCode {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
......@@ -5598,6 +5795,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
// OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头)
if fingerprint != nil {
......@@ -5635,7 +5833,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
......@@ -6107,6 +6305,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if isStream {
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
}
// Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。
// 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。
if getHeaderRaw(req.Header, "x-client-request-id") == "" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString())
}
}
func truncateForLog(b []byte, maxBytes int) string {
......@@ -6695,6 +6898,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
pendingEventLines := make([]string, 0, 4)
processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) {
......@@ -6750,6 +6956,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
eventChanged := false
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if dt, _ := delta["type"].(string); dt == "text_delta" {
if text, _ := delta["text"].(string); text != "" {
captureBuilder.WriteString(text)
}
}
}
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
......@@ -6872,7 +7089,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for _, block := range outputBlocks {
if !clientDisconnected {
if _, werr := fmt.Fprint(w, block); werr != nil {
restored := reverseToolNamesIfPresent(c, []byte(block))
if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
break
......@@ -7214,6 +7432,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
body = reverseToolNamesIfPresent(c, body)
// 写入响应
c.Data(resp.StatusCode, contentType, body)
......@@ -8202,12 +8422,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body = stripMessageCacheControl(body)
body = addMessageCacheBreakpoints(body)
if rw := buildToolNameRewriteFromBody(body); rw != nil {
body = applyToolNameRewriteToBody(body, rw)
} else {
body = applyToolsLastCacheBreakpoint(body)
}
}
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
......@@ -8631,7 +8859,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
......
package service
import (
"fmt"
"hash/fnv"
"math/rand"
"sort"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// toolNameRewriteKey is the gin.Context key under which the per-request
// ToolNameRewrite mapping is stored. It is written during the request phase and
// read during the response phase to restore fake names to real names at the
// byte level.
const toolNameRewriteKey = "claude_tool_name_rewrite"
// staticToolNameRewrites is the static prefix mapping (real prefix → fake
// prefix). Per the original author's note it is identical to
// TOOL_NAME_REWRITES in Parrot's src/transform/cc_mimicry.py. Only tool names
// that start with one of these prefixes are rewritten by the static rule.
var staticToolNameRewrites = map[string]string{
"sessions_": "cc_sess_",
"session_": "cc_ses_",
}
// fakeToolNamePrefixes is the prefix pool used by the dynamic mapping (per the
// original note, it matches Parrot's _FAKE_PREFIXES). When the number of tools
// exceeds dynamicToolMapThreshold, prefixes from this pool are combined with a
// fragment of the real name to build readable fake names.
var fakeToolNamePrefixes = []string{
	"analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
	"process_", "query_", "render_", "resolve_", "sync_", "update_",
	"validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
	"review_", "search_", "transform_", "handle_", "invoke_", "notify_",
}

// dynamicToolMapThreshold is the tool count that must be exceeded before the
// dynamic mapping kicks in (more than 5 tools). A small tool set is left
// un-obfuscated by the dynamic rule.
const dynamicToolMapThreshold = 5

// ToolNameRewrite is the tool-name obfuscation mapping for a single request.
//   - Forward: real → fake, applied to the request body.
//   - Reverse: fake → real, used on the response side.
//   - ReverseOrdered: (fake, real) pairs sorted by fake-name length, longest
//     first, so that a short fake name that is a substring of a longer one is
//     never replaced before the longer one.
type ToolNameRewrite struct {
	Forward        map[string]string
	Reverse        map[string]string
	ReverseOrdered [][2]string
}

// buildDynamicToolMap builds the dynamic real → fake alias map for toolNames.
//
// It returns nil when len(toolNames) <= dynamicToolMapThreshold, in which case
// callers fall back to the static prefix rule only.
//
// The mapping is stable for a given ordered list of names: the prefix pool is
// shuffled with a math/rand generator seeded from the FNV-1a hash of the
// NUL-separated names. Each alias has the shape
// "<prefix><first three characters of the name><index as %02d>".
//
// The name fragment is taken on rune boundaries. Slicing the first three bytes
// would cut a multi-byte character in half and yield an invalid UTF-8 alias,
// which no longer round-trips through JSON encoding and therefore could not be
// restored on the response side. For ASCII names the output is unchanged.
// If the same name appears more than once, the last occurrence wins.
func buildDynamicToolMap(toolNames []string) map[string]string {
	if len(toolNames) <= dynamicToolMapThreshold {
		return nil
	}

	// hash.Hash.Write never returns an error, so the results are discarded.
	h := fnv.New64a()
	for i, n := range toolNames {
		if i > 0 {
			_, _ = h.Write([]byte{0})
		}
		_, _ = h.Write([]byte(n))
	}
	rng := rand.New(rand.NewSource(int64(h.Sum64())))

	// Shuffle a private copy so the package-level pool is never mutated.
	available := make([]string, len(fakeToolNamePrefixes))
	copy(available, fakeToolNamePrefixes)
	rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })

	const headRunes = 3
	mapping := make(map[string]string, len(toolNames))
	for i, name := range toolNames {
		// Find the byte offset where the fourth rune starts; names with
		// three runes or fewer are used whole.
		headEnd := len(name)
		seen := 0
		for offset := range name {
			if seen == headRunes {
				headEnd = offset
				break
			}
			seen++
		}
		prefix := available[i%len(available)]
		mapping[name] = fmt.Sprintf("%s%s%02d", prefix, name[:headEnd], i)
	}
	return mapping
}
// sanitizeToolName converts a real tool name into its fake name.
//
// Resolution order: an exact hit in the dynamic map wins; otherwise the first
// matching static prefix in staticToolNameRewrites is swapped for its fake
// prefix; otherwise the name is returned unchanged. dynamic may be nil.
func sanitizeToolName(name string, dynamic map[string]string) string {
	// Indexing a nil map is safe in Go, so a nil dynamic map simply misses.
	if alias, found := dynamic[name]; found {
		return alias
	}
	for realPrefix, fakePrefix := range staticToolNameRewrites {
		if !strings.HasPrefix(name, realPrefix) {
			continue
		}
		return fakePrefix + strings.TrimPrefix(name, realPrefix)
	}
	return name
}
// shouldMimicToolName reports whether a tool of the given type may be renamed.
//
// Only client-defined tools qualify: an empty type, "function" or "custom".
// Any other type is treated as a server tool (the original note cites
// "web_search_20250305" and "computer_20250124" as examples) whose name is part
// of the protocol and must be left alone.
func shouldMimicToolName(toolType string) bool {
	switch toolType {
	case "", "function", "custom":
		return true
	default:
		return false
	}
}
// buildToolNameRewriteFromBody scans tools[*].name in body and builds the
// ToolNameRewrite for the request. It only reads body; the actual rewrite is
// done by applyToolNameRewriteToBody.
//
// It returns nil when "tools" is not an array or when no tool name would
// change (no dynamic mapping applies and no static prefix matches). Tools with
// an empty name, or whose type fails shouldMimicToolName, are ignored.
func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
	toolsNode := gjson.GetBytes(body, "tools")
	if !toolsNode.IsArray() {
		return nil
	}

	// Collect the names that are eligible for renaming, in document order.
	candidates := make([]string, 0)
	toolsNode.ForEach(func(_, tool gjson.Result) bool {
		if !shouldMimicToolName(tool.Get("type").String()) {
			return true
		}
		if toolName := tool.Get("name").String(); toolName != "" {
			candidates = append(candidates, toolName)
		}
		return true
	})

	dynamic := buildDynamicToolMap(candidates)
	forward := make(map[string]string)
	reverse := make(map[string]string)
	for _, toolName := range candidates {
		if alias := sanitizeToolName(toolName, dynamic); alias != toolName {
			forward[toolName] = alias
			reverse[alias] = toolName
		}
	}
	if len(forward) == 0 {
		return nil
	}

	// Longest fake name first, so substring aliases cannot shadow longer ones
	// during response-side replacement.
	ordered := make([][2]string, 0, len(reverse))
	for alias, orig := range reverse {
		ordered = append(ordered, [2]string{alias, orig})
	}
	sort.SliceStable(ordered, func(a, b int) bool {
		return len(ordered[a][0]) > len(ordered[b][0])
	})

	return &ToolNameRewrite{
		Forward:        forward,
		Reverse:        reverse,
		ReverseOrdered: ordered,
	}
}
// applyToolNameRewriteToBody applies an already-built ToolNameRewrite to body:
//   - rewrites $.tools[*].name for tools that pass shouldMimicToolName and have
//     an entry in rw.Forward;
//   - rewrites $.tool_choice.name when $.tool_choice.type == "tool" and the name
//     has an entry in rw.Forward;
//   - always finishes by calling applyToolsLastCacheBreakpoint, including when
//     rw is nil or empty.
//
// tool_use names inside $.messages are intentionally not rewritten here (the
// original note states this matches Parrot). Individual sjson failures are
// skipped, leaving that field as it was.
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
	if rw != nil && len(rw.Forward) > 0 {
		if toolsNode := gjson.GetBytes(body, "tools"); toolsNode.IsArray() {
			for i, tool := range toolsNode.Array() {
				if !shouldMimicToolName(tool.Get("type").String()) {
					continue
				}
				toolName := tool.Get("name").String()
				if toolName == "" {
					continue
				}
				alias, mapped := rw.Forward[toolName]
				if !mapped {
					continue
				}
				updated, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", i), alias)
				if err == nil {
					body = updated
				}
			}
		}

		choice := gjson.GetBytes(body, "tool_choice")
		if choice.Exists() && choice.Get("type").String() == "tool" {
			if alias, mapped := rw.Forward[choice.Get("name").String()]; mapped {
				updated, err := sjson.SetBytes(body, "tool_choice.name", alias)
				if err == nil {
					body = updated
				}
			}
		}
	}

	return applyToolsLastCacheBreakpoint(body)
}
// applyToolsLastCacheBreakpoint puts a cache_control breakpoint on the last
// entry of the tools array. The rules are:
//   - the last tool already has cache_control with a non-empty ttl → body is
//     returned untouched;
//   - the last tool has cache_control without a ttl → only ttl is filled in
//     with claude.DefaultCacheControlTTL;
//   - no cache_control at all → {"type":"ephemeral","ttl":<default>} is added.
//
// It is a no-op when tools is missing, not an array, or empty. sjson failures
// leave body unchanged.
func applyToolsLastCacheBreakpoint(body []byte) []byte {
	toolsNode := gjson.GetBytes(body, "tools")
	if !toolsNode.IsArray() {
		return body
	}
	entries := toolsNode.Array()
	if len(entries) == 0 {
		return body
	}

	last := len(entries) - 1
	current := entries[last].Get("cache_control")

	switch {
	case !current.Exists():
		// Nothing there yet: inject the full default breakpoint object.
		payload := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
		if updated, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", last), []byte(payload)); err == nil {
			body = updated
		}
	case current.Get("ttl").String() == "":
		// Breakpoint present but without ttl: add only the default ttl.
		if updated, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", last), claude.DefaultCacheControlTTL); err == nil {
			body = updated
		}
	}
	return body
}
// restoreToolNamesInBytes reverses the obfuscation on a chunk of bytes,
// turning fake names back into real ones.
//
// First every (fake, real) pair in rw.ReverseOrdered is replaced in the stored
// order (longest fake name first, so substring aliases do not collide). Then
// the static prefixes are rolled back (cc_sess_ → sessions_, cc_ses_ →
// session_).
//
// rw may be nil; the static prefix rollback still runs in that case.
func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
	if rw != nil {
		for i := range rw.ReverseOrdered {
			alias, orig := rw.ReverseOrdered[i][0], rw.ReverseOrdered[i][1]
			if alias != "" && alias != orig {
				data = replaceAllBytes(data, alias, orig)
			}
		}
	}
	for realPrefix, fakePrefix := range staticToolNameRewrites {
		data = replaceAllBytes(data, fakePrefix, realPrefix)
	}
	return data
}
// replaceAllBytes replaces every occurrence of from with to inside data and
// returns the result. It exists so call sites do not each repeat the
// []byte ↔ string conversions.
//
// data is returned as-is (same slice, no copy) when it is empty, when from is
// empty, when from == to, or when from does not occur in data. The empty-from
// guard matters: strings.ReplaceAll with an empty old string inserts the
// replacement before every rune and at the end, which would corrupt the
// payload instead of leaving it alone.
func replaceAllBytes(data []byte, from, to string) []byte {
	if len(data) == 0 || from == "" || from == to {
		return data
	}
	// Convert once and reuse the string for both the lookup and the rewrite.
	text := string(data)
	if !strings.Contains(text, from) {
		return data
	}
	return []byte(strings.ReplaceAll(text, from, to))
}
// toolNameRewriteFromContext fetches the tool-name mapping that the request
// phase stored under toolNameRewriteKey.
//
// It returns nil when c is a nil interface, when the key is absent, or when
// the stored value is not a *ToolNameRewrite; callers must handle nil.
// NOTE(review): the c == nil check only catches an untyped nil interface; a
// typed nil pointer passed as c would still reach c.Get — confirm callers
// never pass one.
func toolNameRewriteFromContext(c interface {
	Get(string) (any, bool)
}) *ToolNameRewrite {
	if c == nil {
		return nil
	}
	stored, present := c.Get(toolNameRewriteKey)
	if !present {
		return nil
	}
	// A nil or wrongly-typed value fails the assertion and yields nil.
	if rw, isRewrite := stored.(*ToolNameRewrite); isRewrite {
		return rw
	}
	return nil
}
// reverseToolNamesIfPresent is the shared wrapper for the response-side call
// sites: it looks up the mapping stored on c and restores fake names to real
// names in chunk at the byte level.
//
// When c carries no mapping the static prefix rollback still runs; chunk is
// returned untouched only if there is neither a mapping nor any static rule.
func reverseToolNamesIfPresent(c interface {
	Get(string) (any, bool)
}, chunk []byte) []byte {
	rw := toolNameRewriteFromContext(c)
	if rw != nil || len(staticToolNameRewrites) != 0 {
		return restoreToolNamesInBytes(chunk, rw)
	}
	return chunk
}
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
// Parrot 行为:tools 数量 ≤ 5 时不做动态映射。
names := []string{"bash", "edit", "read", "write", "search"}
require.Nil(t, buildDynamicToolMap(names))
}
// TestBuildDynamicToolMap_AboveThresholdIsStable checks that six tools yield a
// complete mapping and that repeating the call gives the same result.
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
	// Parrot invariant: the same tool_names map identically within one
	// process (keeps upstream caching effective).
	toolNames := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
	first := buildDynamicToolMap(toolNames)
	second := buildDynamicToolMap(toolNames)

	require.NotNil(t, first)
	require.Equal(t, first, second, "same input tool_names must yield identical mapping")
	require.Len(t, first, 6)
	for i := range toolNames {
		orig := toolNames[i]
		require.Contains(t, first, orig)
		require.NotEqual(t, orig, first[orig])
	}
}
// TestSanitizeToolName_StaticPrefix checks the static prefix rule when no
// dynamic map is supplied.
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: "sessions_list", want: "cc_sess_list"},
		{input: "session_get", want: "cc_ses_get"},
		{input: "bash", want: "bash"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, sanitizeToolName(tc.input, nil))
	}
}
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
dyn := map[string]string{"sessions_list": "analyze_ses00"}
got := sanitizeToolName("sessions_list", dyn)
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
}
// TestRestoreToolNamesInBytes_LongestFirst checks that a fake name which is a
// substring of a longer fake name does not pre-empt the longer replacement.
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
	// "abc_12" is a substring of "abc_12_ext". This is rare in practice, but
	// the algorithm must defend against it, so the mapping is hand-built with
	// the longer fake name listed first in ReverseOrdered.
	mapping := &ToolNameRewrite{
		Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
		Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
		ReverseOrdered: [][2]string{
			{"abc_12_ext", "bar"},
			{"abc_12", "foo"},
		},
	}

	payload := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
	got := string(restoreToolNamesInBytes(payload, mapping))
	require.Equal(t, `{"tool":"bar","other":"foo"}`, got)
}
// TestRestoreToolNamesInBytes_StaticPrefixRollback checks that static fake
// prefixes are rolled back even when no per-request mapping is given.
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
	payload := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
	restored := restoreToolNamesInBytes(payload, nil)
	require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, string(restored))
}
// TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice checks that client
// tools and tool_choice are renamed while a server tool is left untouched.
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
	reqBody := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)

	mapping := buildToolNameRewriteFromBody(reqBody)
	require.NotNil(t, mapping)
	require.Contains(t, mapping.Forward, "sessions_list")
	require.Contains(t, mapping.Forward, "session_get")
	// web_search is a server tool and must not be part of the mapping.
	require.NotContains(t, mapping.Forward, "web_search")

	rewritten := applyToolNameRewriteToBody(reqBody, mapping)
	field := func(path string) string {
		return gjson.GetBytes(rewritten, path).String()
	}

	// The first two tools are renamed; the server tool keeps its name.
	require.Equal(t, "cc_sess_list", field("tools.0.name"))
	require.Equal(t, "cc_ses_get", field("tools.1.name"))
	require.Equal(t, "web_search", field("tools.2.name"))
	// tool_choice follows the rename and keeps its type.
	require.Equal(t, "cc_sess_list", field("tool_choice.name"))
	require.Equal(t, "tool", field("tool_choice.type"))
}
// TestApplyToolsLastCacheBreakpoint_InjectsDefault checks that only the last
// tool receives the default ephemeral breakpoint.
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
	reqBody := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
	patched := applyToolsLastCacheBreakpoint(reqBody)

	lastCC := gjson.GetBytes(patched, "tools.1.cache_control")
	require.Equal(t, "ephemeral", lastCC.Get("type").String())
	require.Equal(t, "5m", lastCC.Get("ttl").String())
	// The first tool must stay without a breakpoint.
	require.False(t, gjson.GetBytes(patched, "tools.0.cache_control").Exists())
}
// TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL checks that a ttl
// supplied by the client is not overwritten.
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
	reqBody := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
	patched := applyToolsLastCacheBreakpoint(reqBody)

	// The user-provided ttl must survive.
	ttl := gjson.GetBytes(patched, "tools.0.cache_control.ttl").String()
	require.Equal(t, "1h", ttl)
}
// TestStripMessageCacheControl checks that cache_control is removed from a
// message content block.
func TestStripMessageCacheControl(t *testing.T) {
	reqBody := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
	stripped := stripMessageCacheControl(reqBody)

	leftover := gjson.GetBytes(stripped, "messages.0.content.0.cache_control")
	require.False(t, leftover.Exists())
}
// TestAddMessageCacheBreakpoints_LastMessageOnly checks that a single-message
// conversation gets one breakpoint on its only content block.
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
	reqBody := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
	patched := addMessageCacheBreakpoints(reqBody)

	breakpoint := gjson.GetBytes(patched, "messages.0.content.0.cache_control")
	require.Equal(t, "ephemeral", breakpoint.Get("type").String())
	require.Equal(t, "5m", breakpoint.Get("ttl").String())
}
// TestAddMessageCacheBreakpoints_SecondToLastUserTurn checks breakpoint
// placement for a four-message conversation: one on the last message and one
// on the second-to-last user turn, with nothing on the other messages.
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
// Parrot invariant: the second breakpoint is only added when there are at
// least 4 messages, and it is placed on the second-to-last user turn.
body := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"q1"}]},
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
{"role":"user","content":[{"type":"text","text":"q2"}]},
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
]}`)
out := addMessageCacheBreakpoints(body)
// The final assistant message gets a breakpoint.
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
// Second-to-last user turn = index 0 (the only user message other than index 2).
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
// No other message gets a breakpoint.
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
}
// TestAddMessageCacheBreakpoints_StringContentPromoted checks that plain
// string content is turned into a one-element text block array carrying the
// breakpoint.
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
	reqBody := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	patched := addMessageCacheBreakpoints(reqBody)

	// content is promoted from a string to an array.
	content := gjson.GetBytes(patched, "messages.0.content")
	require.True(t, content.IsArray())

	firstBlock := gjson.GetBytes(patched, "messages.0.content.0")
	require.Equal(t, "text", firstBlock.Get("type").String())
	require.Equal(t, "hi", firstBlock.Get("text").String())
	require.Equal(t, "5m", firstBlock.Get("cache_control.ttl").String())
}
// TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc checks that the
// ReverseOrdered list produced for a six-tool body is sorted by fake-name
// length, longest first.
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
// Six tools exceed the threshold and trigger the dynamic mapping; verify that
// ReverseOrdered is sorted by fake-name length in descending order.
body := []byte(`{"tools":[
{"name":"t1","input_schema":{}},
{"name":"t2","input_schema":{}},
{"name":"t3","input_schema":{}},
{"name":"t4","input_schema":{}},
{"name":"t5","input_schema":{}},
{"name":"t6","input_schema":{}}
]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.NotEmpty(t, rw.ReverseOrdered)
// Each entry must be at least as long as the one after it.
for i := 1; i < len(rw.ReverseOrdered); i++ {
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
"ReverseOrdered must be sorted by fake-name length descending")
}
}
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
data := []byte("plain text without any tool names")
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
}
// TestBuildDynamicToolMap_FakeNameShape ensures every fake name follows
// Parrot's "{prefix}{name[:3]}{i:02d}" shape.
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
	toolNames := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
	mapping := buildDynamicToolMap(toolNames)
	require.NotNil(t, mapping)

	for _, orig := range toolNames {
		alias, found := mapping[orig]
		require.True(t, found)
		// alias = prefix + first three characters + "%02d", so it must end
		// with two decimal digits.
		require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, alias)

		head3 := orig
		if len(orig) > 3 {
			head3 = orig[:3]
		}
		require.True(t, strings.Contains(alias, head3), "fake %q should contain head3 %q of %q", alias, head3, orig)
	}
}
......@@ -1008,6 +1008,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var usage *ClaudeUsage
var firstTokenMs *int
var responseBody string
if req.Stream {
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
......@@ -1015,6 +1016,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
......@@ -1023,16 +1025,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
collectedBytes, _ := json.Marshal(collected)
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes)
c.JSON(http.StatusOK, claudeResp)
respBytes, _ := json.Marshal(claudeResp)
c.Data(http.StatusOK, "application/json", respBytes)
responseBody = string(respBytes)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
usage = usageObj
}
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
var nonStreamBody string
nonStreamBody, usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
responseBody = nonStreamBody
}
}
......@@ -1053,6 +1059,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
ResponseBody: responseBody,
}, nil
}
......@@ -1872,28 +1879,30 @@ func mapGeminiStatusToClaudeErrorType(status string) string {
type geminiStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
responseBody string // 累积的文本内容,用于响应捕获
}
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (string, *ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
unwrappedBody, err := unwrapGeminiResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
var geminiResp map[string]any
if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody)
c.JSON(http.StatusOK, claudeResp)
respBytes, _ := json.Marshal(claudeResp)
c.Data(http.StatusOK, "application/json", respBytes)
return usage, nil
return string(respBytes), usage, nil
}
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
......@@ -2146,7 +2155,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
})
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs, responseBody: seenText}, nil
}
func writeSSE(w io.Writer, event string, data any) {
......
......@@ -26,7 +26,7 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.1.22 (external, cli)",
UserAgent: "claude-cli/2.1.92 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",
......
......@@ -3,7 +3,6 @@ package service
import (
"container/heap"
"context"
"errors"
"fmt"
"hash/fnv"
"math"
......@@ -45,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
RequiredImageCapability OpenAIImagesCapability
RequireCompact bool
ExcludedIDs map[int64]struct{}
}
......@@ -258,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select(
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
req.RequireCompact,
)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
selection = nil
}
}
......@@ -348,8 +352,8 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
......@@ -590,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
// require_privacy_set: 获取分组信息
......@@ -630,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
loadMap := map[int64]*AccountLoadInfo{}
......@@ -640,53 +644,77 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
if account.Priority > maxPriority {
maxPriority = account.Priority
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
// 时作为最后兜底(snapshot 可能已陈旧)。
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
candidates = append(candidates, candidate)
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
if hasTTFT && ttft > 0 {
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
}
candidateCount := len(candidates)
loadSkew := 0.0
if len(candidates) > 0 {
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(loadInfo.LoadRate)
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
......@@ -709,30 +737,105 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
}
topK := s.service.openAIWSLBTopK()
topK := 0
if len(candidates) > 0 {
topK = s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
}
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || topK <= 0 {
return nil
}
groupTopK := topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
for _, candidate := range candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
}
} else {
selectionOrder = buildSelectionOrder(candidates)
}
if len(selectionOrder) == 0 {
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
}
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
......@@ -742,17 +845,25 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
......@@ -761,10 +872,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}, candidateCount, topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked)
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
......@@ -905,8 +1016,9 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
}
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
......@@ -917,13 +1029,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
excludedIDs map[int64]struct{},
requiredCapability OpenAIImagesCapability,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
if err == nil && selection != nil && selection.Account != nil {
return selection, decision, nil
}
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
if requiredCapability == OpenAIImagesCapabilityNative {
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic)
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
}
return selection, decision, err
}
......@@ -937,6 +1049,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requiredImageCapability OpenAIImagesCapability,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
......@@ -945,7 +1058,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
......@@ -970,7 +1083,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
if err != nil {
return nil, decision, err
}
......@@ -1008,6 +1121,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
RequestedModel: requestedModel,
RequiredTransport: requiredTransport,
RequiredImageCapability: requiredImageCapability,
RequireCompact: requireCompact,
ExcludedIDs: excludedIDs,
})
}
......
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown
// verifies that compact scheduling prefers an explicitly supported account (tier=2)
// over an account that has not been probed yet (tier=1).
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) {
	resetOpenAIAdvancedSchedulerSettingCacheForTest()

	// Both accounts are identical except for their ID and compact-related Extra data.
	newAccount := func(id int64, extra map[string]any) Account {
		return Account{
			ID:          id,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
			Extra:       extra,
		}
	}
	accounts := []Account{
		newAccount(71001, map[string]any{}),                                 // unknown
		newAccount(71002, map[string]any{"openai_compact_supported": true}), // tier=2
	}

	cfg := &config.Config{}
	cfg.Gateway.Scheduling.LoadBatchEnabled = false
	svc := &OpenAIGatewayService{
		accountRepo:        schedulerTestOpenAIAccountRepo{accounts: accounts},
		cache:              &schedulerTestGatewayCache{},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
	}

	var groupID int64 = 91001
	// The final argument (requireCompact) is true: this is a compact request.
	selection, _, err := svc.SelectAccountWithScheduler(context.Background(), &groupID, "", "", "gpt-5.4", nil, OpenAIUpstreamTransportAny, true)

	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown")
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported
// verifies that accounts marked force_off or already probed as unsupported (tier=0)
// are never selected for a compact request.
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) {
	resetOpenAIAdvancedSchedulerSettingCacheForTest()

	// Both accounts are identical except for their ID and compact-related Extra data.
	newAccount := func(id int64, extra map[string]any) Account {
		return Account{
			ID:          id,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
			Extra:       extra,
		}
	}
	accounts := []Account{
		newAccount(71010, map[string]any{"openai_compact_mode": OpenAICompactModeForceOff}),
		newAccount(71011, map[string]any{"openai_compact_supported": false}),
	}

	cfg := &config.Config{}
	cfg.Gateway.Scheduling.LoadBatchEnabled = false
	svc := &OpenAIGatewayService{
		accountRepo:        schedulerTestOpenAIAccountRepo{accounts: accounts},
		cache:              &schedulerTestGatewayCache{},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
	}

	var groupID int64 = 91002
	// The final argument (requireCompact) is true: this is a compact request.
	selection, _, err := svc.SelectAccountWithScheduler(context.Background(), &groupID, "", "", "gpt-5.4", nil, OpenAIUpstreamTransportAny, true)

	require.Error(t, err)
	isCompactErr := errors.Is(err, ErrNoAvailableCompactAccounts)
	require.True(t, isCompactErr, "compact-only accounts should rejected explicitly unsupported and return compact error")
	require.Nil(t, selection)
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown
// verifies that a compact request falls back to an unprobed account when no
// account is known to support compact.
func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) {
	resetOpenAIAdvancedSchedulerSettingCacheForTest()

	// Both accounts are identical except for their ID and compact-related Extra data.
	newAccount := func(id int64, extra map[string]any) Account {
		return Account{
			ID:          id,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
			Extra:       extra,
		}
	}
	accounts := []Account{
		newAccount(71020, map[string]any{"openai_compact_supported": false}), // tier=0
		newAccount(71021, map[string]any{}),                                  // unknown -> tier=1
	}

	cfg := &config.Config{}
	cfg.Gateway.Scheduling.LoadBatchEnabled = false
	svc := &OpenAIGatewayService{
		accountRepo:        schedulerTestOpenAIAccountRepo{accounts: accounts},
		cache:              &schedulerTestGatewayCache{},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
	}

	var groupID int64 = 91003
	// The final argument (requireCompact) is true: this is a compact request.
	selection, _, err := svc.SelectAccountWithScheduler(context.Background(), &groupID, "", "", "gpt-5.4", nil, OpenAIUpstreamTransportAny, true)

	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available")
}
// TestOpenAICompactSupportTier checks the tier value openAICompactSupportTier
// returns for each combination of platform, probe result and compact mode.
func TestOpenAICompactSupportTier(t *testing.T) {
	// openAI builds an OpenAI-platform account carrying the given Extra data.
	openAI := func(extra map[string]any) *Account {
		return &Account{Platform: PlatformOpenAI, Extra: extra}
	}

	cases := []struct {
		name    string
		account *Account
		want    int
	}{
		{"nil", nil, 0},
		{"non openai", &Account{Platform: PlatformAnthropic}, 0},
		{"openai unknown", openAI(map[string]any{}), 1},
		{"openai supported", openAI(map[string]any{"openai_compact_supported": true}), 2},
		{"openai unsupported", openAI(map[string]any{"openai_compact_supported": false}), 0},
		{"force on", openAI(map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}), 2},
		{"force off overrides probe true", openAI(map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}), 0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := openAICompactSupportTier(tc.account)
			if got == tc.want {
				return
			}
			t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tc.want)
		})
	}
}
......@@ -289,6 +289,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -343,6 +344,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -384,6 +386,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.ErrorContains(t, err, "no available OpenAI accounts")
require.Nil(t, selection)
......@@ -445,6 +448,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -486,7 +490,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
......@@ -540,7 +544,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
......@@ -616,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -662,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -740,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -788,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -857,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -900,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.Error(t, err)
require.Nil(t, selection)
......@@ -976,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......@@ -1014,7 +1025,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
......@@ -1218,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......
......@@ -54,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
......
package service
import (
"encoding/json"
"fmt"
"strings"
)
......@@ -48,6 +49,8 @@ type codexTransformResult struct {
const (
codexImageGenerationBridgeMarker = "<sub2api-codex-image-generation>"
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n</sub2api-codex-image-generation>"
codexSparkImageUnsupportedMarker = "<sub2api-codex-spark-image-unsupported>"
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
......@@ -151,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if normalizeCodexTools(reqBody) {
result.Modified = true
}
if normalizeCodexToolChoice(reqBody) {
result.Modified = true
}
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
......@@ -165,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if applyInstructions(reqBody, isCodexCLI) {
result.Modified = true
}
if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
result.Modified = true
}
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
input = normalizedInput
result.Modified = true
}
if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
input = normalizedInput
result.Modified = true
}
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
......@@ -192,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
return result
}
// normalizeCodexToolChoice resets an object-form "tool_choice" to "auto" when
// its "type" does not match the type of any entry in reqBody["tools"].
// String-form, missing, or nil tool_choice values and object values without a
// type are left untouched. It reports whether reqBody was modified.
func normalizeCodexToolChoice(reqBody map[string]any) bool {
	// A missing or nil tool_choice fails this assertion too, so one check covers
	// every "not an object" case.
	selector, isObject := reqBody["tool_choice"].(map[string]any)
	if !isObject {
		return false
	}

	wantedType := strings.TrimSpace(firstNonEmptyString(selector["type"]))
	switch {
	case wantedType == "":
		return false
	case codexToolsContainType(reqBody["tools"], wantedType):
		return false
	}

	reqBody["tool_choice"] = "auto"
	return true
}
// codexToolsContainType reports whether rawTools is a []any holding at least
// one object whose trimmed "type" equals toolType. A blank toolType or a
// rawTools value of any other shape yields false; non-object entries are ignored.
func codexToolsContainType(rawTools any, toolType string) bool {
	if strings.TrimSpace(toolType) == "" {
		return false
	}
	toolList, isList := rawTools.([]any)
	if !isList {
		return false
	}
	for i := range toolList {
		if entry, isObject := toolList[i].(map[string]any); isObject &&
			strings.TrimSpace(firstNonEmptyString(entry["type"])) == toolType {
			return true
		}
	}
	return false
}
// normalizeCodexToolRoleMessages rewrites every input item whose role is "tool".
//
// An item with a call id (taken from "call_id", "tool_call_id" or "id") becomes
// a "function_call_output" item; an item without one is copied with its role
// changed to "user" and its "tool_call_id" key removed. All other items pass
// through unchanged. The original slice is returned together with false when
// nothing had to be rewritten; otherwise a new slice and true are returned.
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
	if len(input) == 0 {
		return input, false
	}

	rewritten := make([]any, 0, len(input))
	changed := false
	for _, entry := range input {
		msg, isObject := entry.(map[string]any)
		// Indexing a nil map is safe, so role is simply "" for non-object entries.
		role, _ := msg["role"].(string)
		if !isObject || strings.TrimSpace(role) != "tool" {
			rewritten = append(rewritten, entry)
			continue
		}
		changed = true

		callID := strings.TrimSpace(firstNonEmptyString(msg["call_id"], msg["tool_call_id"], msg["id"]))
		if callID == "" {
			// Responses does not accept role:"tool". Without a call id the item
			// cannot become a function_call_output, so keep its content as a user
			// message rather than sending invalid input.
			asUser := make(map[string]any, len(msg))
			for field, val := range msg {
				asUser[field] = val
			}
			asUser["role"] = "user"
			delete(asUser, "tool_call_id")
			rewritten = append(rewritten, asUser)
			continue
		}

		// Output preference: text extracted from content, then a string
		// "output" field, then the raw content serialized as JSON.
		text := extractTextFromContent(msg["content"])
		if text == "" {
			if raw, isString := msg["output"].(string); isString {
				text = raw
			}
		}
		if text == "" && msg["content"] != nil {
			if encoded, err := json.Marshal(msg["content"]); err == nil {
				text = string(encoded)
			}
		}
		rewritten = append(rewritten, map[string]any{
			"type":    "function_call_output",
			"call_id": callID,
			"output":  text,
		})
	}

	if changed {
		return rewritten, true
	}
	return input, false
}
// normalizeCodexMessageContentText makes sure every content part of a
// "message" item carries a string "text" value. Parts whose "text" is present
// but not a string are replaced by a copy whose text has been converted with
// stringifyCodexContentText.
//
// Items and parts are copied before being changed, so the maps and slices
// reachable from input are never mutated. When nothing needs converting, the
// original slice is returned together with false.
func normalizeCodexMessageContentText(input []any) ([]any, bool) {
	if len(input) == 0 {
		return input, false
	}

	result := make([]any, 0, len(input))
	changed := false
	for _, entry := range input {
		msg, isObject := entry.(map[string]any)
		if !isObject || strings.TrimSpace(firstNonEmptyString(msg["type"])) != "message" {
			result = append(result, entry)
			continue
		}
		parts, isList := msg["content"].([]any)
		if !isList {
			result = append(result, entry)
			continue
		}

		// Copy-on-write: fixedParts stays nil until the first part needs fixing.
		var fixedParts []any
		for idx := range parts {
			part, isPartObject := parts[idx].(map[string]any)
			if !isPartObject {
				continue
			}
			rawText, present := part["text"]
			if !present {
				continue
			}
			if _, isString := rawText.(string); isString {
				continue
			}
			if fixedParts == nil {
				fixedParts = make([]any, len(parts))
				copy(fixedParts, parts)
			}
			fixedPart := make(map[string]any, len(part))
			for field, val := range part {
				fixedPart[field] = val
			}
			fixedPart["text"] = stringifyCodexContentText(rawText)
			fixedParts[idx] = fixedPart
		}

		if fixedParts == nil {
			result = append(result, entry)
			continue
		}
		fixedMsg := make(map[string]any, len(msg))
		for field, val := range msg {
			fixedMsg[field] = val
		}
		fixedMsg["content"] = fixedParts
		result = append(result, fixedMsg)
		changed = true
	}

	if changed {
		return result, true
	}
	return input, false
}
// stringifyCodexContentText converts an arbitrary content "text" value to a
// string: strings are returned as-is, nil becomes "", and anything else is
// rendered as JSON, falling back to fmt.Sprint when JSON encoding fails.
func stringifyCodexContentText(value any) string {
	if value == nil {
		return ""
	}
	if text, isString := value.(string); isString {
		return text
	}
	encoded, err := json.Marshal(value)
	if err != nil {
		// The value cannot be JSON-encoded; use fmt's default formatting instead.
		return fmt.Sprint(value)
	}
	return string(encoded)
}
func normalizeCodexModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
......@@ -244,6 +438,10 @@ func normalizeCodexModel(model string) string {
return "gpt-5.4"
}
// isCodexSparkModel reports whether model normalizes (via normalizeCodexModel)
// to the Codex Spark model id.
func isCodexSparkModel(model string) bool {
	const sparkModelID = "gpt-5.3-codex-spark"
	normalized := normalizeCodexModel(model)
	return normalized == sparkModelID
}
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
......@@ -265,6 +463,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
return false
}
// hasOpenAIInputImage reports whether the request body carries an image in
// either its "input" or its "messages" field, as judged by
// hasOpenAIInputImageValue. A nil or empty body yields false.
func hasOpenAIInputImage(reqBody map[string]any) bool {
	if len(reqBody) == 0 {
		return false
	}
	for _, field := range [...]string{"input", "messages"} {
		if hasOpenAIInputImageValue(reqBody[field]) {
			return true
		}
	}
	return false
}
// hasOpenAIInputImageValue recursively searches value for an image input.
//
// A []any matches when any element matches. A map matches when its trimmed
// "type" is "input_image" or when it has an "image_url" key (whatever the
// value); otherwise the search continues into its "content" field only.
// Every other kind of value yields false.
func hasOpenAIInputImageValue(value any) bool {
	if list, isList := value.([]any); isList {
		for i := range list {
			if hasOpenAIInputImageValue(list[i]) {
				return true
			}
		}
		return false
	}

	node, isObject := value.(map[string]any)
	if !isObject {
		return false
	}
	if strings.TrimSpace(firstNonEmptyString(node["type"])) == "input_image" {
		return true
	}
	if _, hasURL := node["image_url"]; hasURL {
		return true
	}
	return hasOpenAIInputImageValue(node["content"])
}
// validateCodexSparkInput returns an error when model is the Codex Spark model
// and the request body contains an image input; in every other case it
// returns nil. The error message quotes the trimmed model name.
func validateCodexSparkInput(reqBody map[string]any, model string) error {
	if isCodexSparkModel(model) && hasOpenAIInputImage(reqBody) {
		return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
	}
	return nil
}
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
......@@ -309,6 +541,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
tool := map[string]any{
"type": "image_generation",
......@@ -344,6 +579,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
return false
}
if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
return false
}
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexImageGenerationBridgeMarker) {
......@@ -360,6 +598,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
return true
}
// applyCodexSparkImageUnsupportedInstructions adds the Spark "image unsupported"
// notice to reqBody["instructions"] unless the marker is already present.
//
// Existing instructions keep their text (minus trailing whitespace) and the
// notice is appended after a blank line; blank or missing instructions are
// replaced by the notice alone. It reports whether reqBody was modified, which
// is never the case for an empty body.
func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
	if len(reqBody) == 0 {
		return false
	}
	current, _ := reqBody["instructions"].(string)
	if strings.Contains(current, codexSparkImageUnsupportedMarker) {
		return false
	}

	merged := codexSparkImageUnsupportedText
	if trimmed := strings.TrimRight(current, " \t\r\n"); strings.TrimSpace(trimmed) != "" {
		merged = trimmed + "\n\n" + codexSparkImageUnsupportedText
	}
	reqBody["instructions"] = merged
	return true
}
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
if !hasOpenAIImageGenerationTool(reqBody) {
return nil
......@@ -658,12 +913,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
}
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
if !isCodexToolCallItemType(typ) {
ensureCopy()
delete(newItem, "call_id")
}
if codexInputItemRequiresName(typ) {
if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
name := firstNonEmptyString(m["tool_name"])
if name == "" {
if function, ok := m["function"].(map[string]any); ok {
name = firstNonEmptyString(function["name"])
}
}
if name == "" {
name = "tool"
}
ensureCopy()
newItem["name"] = name
}
}
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
}
filtered = append(filtered, newItem)
......@@ -672,10 +945,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
// isCodexToolCallItemType reports whether typ is one of the explicitly listed
// tool-call or tool-call-output item types. The match is exact: the empty
// string, padded values, and any type not in the list yield false.
func isCodexToolCallItemType(typ string) bool {
	// The stray `if typ == "" {` guard left over from the previous
	// implementation is dropped: it unbalanced the braces, and the default
	// branch already returns false for the empty string.
	switch typ {
	case "function_call",
		"tool_call",
		"local_shell_call",
		"tool_search_call",
		"custom_tool_call",
		"mcp_tool_call",
		"function_call_output",
		"mcp_tool_call_output",
		"custom_tool_call_output",
		"tool_search_output":
		return true
	default:
		return false
	}
}
// codexInputItemRequiresName reports whether an input item of the given type
// is one of the call types listed below ("function_call", "custom_tool_call",
// "mcp_tool_call"). Surrounding whitespace in typ is ignored.
func codexInputItemRequiresName(typ string) bool {
	// Every branch of the switch returns, so the trailing suffix-based return
	// that used to follow it was unreachable and has been removed.
	switch strings.TrimSpace(typ) {
	case "function_call", "custom_tool_call", "mcp_tool_call":
		return true
	default:
		return false
	}
}
func normalizeCodexTools(reqBody map[string]any) bool {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment