Unverified Commit b6d46fd5 authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents fa68cbad fdd8499f
......@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -130,7 +131,7 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return s.resp, s.err
}
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ *tlsfingerprint.Profile) (*http.Response, error) {
return s.resp, s.err
}
......@@ -171,7 +172,7 @@ func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int)
return resp, err
}
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ *tlsfingerprint.Profile) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, concurrency)
}
......
......@@ -89,7 +89,8 @@ type AntigravityTokenInfo struct {
TokenType string `json:"token_type"`
Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
ProjectIDMissing bool `json:"-"`
PlanType string `json:"-"`
}
// ExchangeCode 用 authorization code 交换 token
......@@ -145,13 +146,17 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
// 获取 project_id(部分账户类型可能没有),失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
// 获取 project_id + plan_type(部分账户类型可能没有),失败时重试
loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true
} else {
result.ProjectID = projectID
}
if loadResult != nil {
result.ProjectID = loadResult.ProjectID
if loadResult.Subscription != nil {
result.PlanType = loadResult.Subscription.PlanType
}
}
return result, nil
......@@ -230,13 +235,17 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
tokenInfo.Email = userInfo.Email
}
// 获取 project_id(容错,失败不阻塞)
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
// 获取 project_id + plan_type(容错,失败不阻塞)
loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = projectID
}
if loadResult != nil {
tokenInfo.ProjectID = loadResult.ProjectID
if loadResult.Subscription != nil {
tokenInfo.PlanType = loadResult.Subscription.PlanType
}
}
return tokenInfo, nil
......@@ -288,33 +297,42 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
// 每次刷新都调用 LoadCodeAssist 获取 project_id + plan_type,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = projectID
}
if loadResult != nil {
if loadResult.ProjectID != "" {
tokenInfo.ProjectID = loadResult.ProjectID
}
if loadResult.Subscription != nil {
tokenInfo.PlanType = loadResult.Subscription.PlanType
}
}
return tokenInfo, nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
// loadCodeAssistResult 封装 loadProjectIDWithRetry 的返回结果,
// 同时携带从 LoadCodeAssist 响应中提取的 plan_type 信息。
type loadCodeAssistResult struct {
ProjectID string
Subscription *AntigravitySubscriptionResult
}
// loadProjectIDWithRetry 带重试机制获取 project_id,同时从响应中提取 plan_type。
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (*loadCodeAssistResult, error) {
var lastErr error
var lastSubscription *AntigravitySubscriptionResult
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避:1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
......@@ -324,24 +342,34 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err)
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if loadResp != nil {
sub := NormalizeAntigravitySubscription(loadResp)
lastSubscription = &sub
}
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
return &loadCodeAssistResult{
ProjectID: loadResp.CloudAICompanionProject,
Subscription: lastSubscription,
}, nil
}
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
return &loadCodeAssistResult{
ProjectID: projectID,
Subscription: lastSubscription,
}, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
......@@ -351,7 +379,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
if lastSubscription != nil {
return &loadCodeAssistResult{Subscription: lastSubscription}, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
return nil, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
......@@ -410,7 +441,11 @@ func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Ac
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
result, err := s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
if result != nil {
return result.ProjectID, err
}
return "", err
}
// BuildAccountCredentials 构建账户凭证
......@@ -431,6 +466,9 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType
}
return creds
}
......
package service
import (
"context"
"log/slog"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const (
AntigravityPrivacySet = "privacy_set"
AntigravityPrivacyFailed = "privacy_set_failed"
)
// setAntigravityPrivacy 调用 Antigravity API 设置隐私并验证结果。
// 流程:
// 1. setUserSettings 清空设置 → 检查返回值 {"userSettings":{}}
// 2. fetchUserInfo 二次验证隐私是否已生效(需要 project_id)
//
// 返回 privacy_mode 值:"privacy_set" 成功,"privacy_set_failed" 失败,空串表示无法执行。
func setAntigravityPrivacy(ctx context.Context, accessToken, projectID, proxyURL string) string {
if accessToken == "" {
return ""
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
client, err := antigravity.NewClient(proxyURL)
if err != nil {
slog.Warn("antigravity_privacy_client_error", "error", err.Error())
return AntigravityPrivacyFailed
}
// 第 1 步:调用 setUserSettings,检查返回值
setResp, err := client.SetUserSettings(ctx, accessToken)
if err != nil {
slog.Warn("antigravity_privacy_set_failed", "error", err.Error())
return AntigravityPrivacyFailed
}
if !setResp.IsSuccess() {
slog.Warn("antigravity_privacy_set_response_not_empty",
"user_settings", setResp.UserSettings,
)
return AntigravityPrivacyFailed
}
// 第 2 步:调用 fetchUserInfo 二次验证隐私是否已生效
if strings.TrimSpace(projectID) == "" {
slog.Warn("antigravity_privacy_missing_project_id")
return AntigravityPrivacyFailed
}
userInfo, err := client.FetchUserInfo(ctx, accessToken, projectID)
if err != nil {
slog.Warn("antigravity_privacy_verify_failed", "error", err.Error())
return AntigravityPrivacyFailed
}
if !userInfo.IsPrivate() {
slog.Warn("antigravity_privacy_verify_not_private",
"user_settings", userInfo.UserSettings,
)
return AntigravityPrivacyFailed
}
slog.Info("antigravity_privacy_set_success")
return AntigravityPrivacySet
}
func applyAntigravityPrivacyMode(account *Account, mode string) {
if account == nil || strings.TrimSpace(mode) == "" {
return
}
extra := make(map[string]any, len(account.Extra)+1)
for k, v := range account.Extra {
extra[k] = v
}
extra["privacy_mode"] = mode
account.Extra = extra
}
//go:build unit
package service
import (
"testing"
)
func applyAntigravitySubscriptionResult(account *Account, result AntigravitySubscriptionResult) (map[string]any, map[string]any) {
credentials := make(map[string]any)
for k, v := range account.Credentials {
credentials[k] = v
}
credentials["plan_type"] = result.PlanType
extra := make(map[string]any)
for k, v := range account.Extra {
extra[k] = v
}
if result.SubscriptionStatus != "" {
extra["subscription_status"] = result.SubscriptionStatus
} else {
delete(extra, "subscription_status")
}
if result.SubscriptionError != "" {
extra["subscription_error"] = result.SubscriptionError
} else {
delete(extra, "subscription_error")
}
return credentials, extra
}
func TestApplyAntigravityPrivacyMode_SetsInMemoryExtra(t *testing.T) {
account := &Account{}
applyAntigravityPrivacyMode(account, AntigravityPrivacySet)
if account.Extra == nil {
t.Fatal("expected account.Extra to be initialized")
}
if got := account.Extra["privacy_mode"]; got != AntigravityPrivacySet {
t.Fatalf("expected privacy_mode %q, got %v", AntigravityPrivacySet, got)
}
}
func TestApplyAntigravityPrivacyMode_PreservedBySubscriptionResult(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
"existing": "value",
},
}
applyAntigravityPrivacyMode(account, AntigravityPrivacySet)
_, extra := applyAntigravitySubscriptionResult(account, AntigravitySubscriptionResult{
PlanType: "Pro",
})
if got := extra["privacy_mode"]; got != AntigravityPrivacySet {
t.Fatalf("expected subscription writeback to keep privacy_mode %q, got %v", AntigravityPrivacySet, got)
}
if got := extra["existing"]; got != "value" {
t.Fatalf("expected existing extra fields to be preserved, got %v", got)
}
}
......@@ -12,6 +12,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
......@@ -40,7 +41,7 @@ func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID i
}, nil
}
func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return r.Do(req, proxyURL, accountID, accountConcurrency)
}
......@@ -61,7 +62,7 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account
}, nil
}
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
......
......@@ -10,6 +10,7 @@ import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
......@@ -93,7 +94,7 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
}, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return m.Do(req, proxyURL, accountID, accountConcurrency)
}
......
package service
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const antigravitySubscriptionAbnormal = "abnormal"
// AntigravitySubscriptionResult 表示订阅检测后的规范化结果。
type AntigravitySubscriptionResult struct {
PlanType string
SubscriptionStatus string
SubscriptionError string
}
// NormalizeAntigravitySubscription 从 LoadCodeAssistResponse 提取 plan_type + 异常状态。
// 使用 GetTier()(返回 tier ID)+ TierIDToPlanType 映射。
func NormalizeAntigravitySubscription(resp *antigravity.LoadCodeAssistResponse) AntigravitySubscriptionResult {
if resp == nil {
return AntigravitySubscriptionResult{PlanType: "Free"}
}
if len(resp.IneligibleTiers) > 0 {
result := AntigravitySubscriptionResult{
PlanType: "Abnormal",
SubscriptionStatus: antigravitySubscriptionAbnormal,
}
if resp.IneligibleTiers[0] != nil {
result.SubscriptionError = strings.TrimSpace(resp.IneligibleTiers[0].ReasonMessage)
}
return result
}
tierID := resp.GetTier()
return AntigravitySubscriptionResult{
PlanType: antigravity.TierIDToPlanType(tierID),
}
}
......@@ -235,6 +235,12 @@ const (
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
SettingKeyBackendModeEnabled = "backend_mode_enabled"
// Gateway Forwarding Behavior
// SettingKeyEnableFingerprintUnification 是否统一 OAuth 账号的 X-Stainless-* 指纹头(默认 true)
SettingKeyEnableFingerprintUnification = "enable_fingerprint_unification"
// SettingKeyEnableMetadataPassthrough 是否透传客户端原始 metadata.user_id(默认 false)
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
......@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
......@@ -35,7 +36,7 @@ func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
......
......@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
......@@ -60,7 +61,7 @@ func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, a
return u.resp, nil
}
func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
......@@ -175,13 +176,13 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
require.Equal(t, "2023-06-01", upstream.lastReq.Header.Get("anthropic-version"))
require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta"))
require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头")
require.Equal(t, "upstream-anthropic-key", getHeaderRaw(upstream.lastReq.Header, "x-api-key"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "authorization"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "x-goog-api-key"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "cookie"))
require.Equal(t, "2023-06-01", getHeaderRaw(upstream.lastReq.Header, "anthropic-version"))
require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(upstream.lastReq.Header, "anthropic-beta"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头")
require.Contains(t, rec.Body.String(), `"cached_tokens":7`)
require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写")
......@@ -257,9 +258,9 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
require.NoError(t, err)
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
require.Equal(t, "upstream-anthropic-key", getHeaderRaw(upstream.lastReq.Header, "x-api-key"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "authorization"))
require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "cookie"))
require.Equal(t, http.StatusOK, rec.Code)
require.JSONEq(t, upstreamRespBody, rec.Body.String())
require.Empty(t, rec.Header().Get("Set-Cookie"))
......@@ -684,8 +685,8 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false)
require.NoError(t, err)
require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization"))
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
require.Equal(t, "Bearer oauth-token", getHeaderRaw(req.Header, "authorization"))
require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
}
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
......@@ -755,8 +756,8 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
require.Equal(t, "Bearer oauth-token", getHeaderRaw(upstream.lastReq.Header, "authorization"))
require.Contains(t, getHeaderRaw(upstream.lastReq.Header, "anthropic-beta"), claude.BetaOAuth)
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
......
......@@ -2,31 +2,28 @@ package service
import "testing"
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
t.Run("default disabled", func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, "")
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled by default")
func TestParseDebugEnvBool(t *testing.T) {
t.Run("empty is false", func(t *testing.T) {
if parseDebugEnvBool("") {
t.Fatalf("expected false for empty string")
}
})
t.Run("enabled with true-like values", func(t *testing.T) {
t.Run("true-like values", func(t *testing.T) {
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if !debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
if !parseDebugEnvBool(value) {
t.Fatalf("expected true for %q", value)
}
})
}
})
t.Run("disabled with other values", func(t *testing.T) {
t.Run("false-like values", func(t *testing.T) {
for _, value := range []string{"0", "false", "off", "debug"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
if parseDebugEnvBool(value) {
t.Fatalf("expected false for %q", value)
}
})
}
......
......@@ -120,7 +120,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......
......@@ -117,7 +117,7 @@ func (s *GatewayService) ForwardAsResponses(
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......
......@@ -124,6 +124,27 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
},
want: false,
},
// json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions)
{
name: "json.RawMessage string with Claude Code prompt",
system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`),
want: true,
},
{
name: "json.RawMessage string without Claude Code prompt",
system: json.RawMessage(`"You are a helpful assistant"`),
want: false,
},
{
name: "json.RawMessage nil (empty)",
system: json.RawMessage(nil),
want: false,
},
{
name: "json.RawMessage empty string",
system: json.RawMessage(`""`),
want: false,
},
}
for _, tt := range tests {
......@@ -202,6 +223,29 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
// json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions)
{
name: "json.RawMessage string system",
body: `{"model":"claude-3","system":"Custom prompt"}`,
system: json.RawMessage(`"Custom prompt"`),
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: claudePrefix + "\n\nCustom prompt",
},
{
name: "json.RawMessage nil system",
body: `{"model":"claude-3"}`,
system: json.RawMessage(nil),
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
{
name: "json.RawMessage Claude Code prompt (should not duplicate)",
body: `{"model":"claude-3","system":"` + claudeCodeSystemPrompt + `"}`,
system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`),
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
}
for _, tt := range tests {
......
......@@ -40,6 +40,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}
......
......@@ -13,6 +13,7 @@ import (
mathrand "math/rand"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
......@@ -366,6 +367,7 @@ var allowedHeaders = map[string]bool{
"sec-fetch-mode": true,
"user-agent": true,
"content-type": true,
"accept-encoding": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
......@@ -563,6 +565,8 @@ type GatewayService struct {
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
}
// NewGatewayService creates a new GatewayService
......@@ -589,6 +593,7 @@ func NewGatewayService(
rpmCache RPMCache,
digestStore *DigestSessionStore,
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
......@@ -620,6 +625,7 @@ func NewGatewayService(
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
......@@ -630,6 +636,9 @@ func NewGatewayService(
)
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
if path := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv)); path != "" {
svc.initDebugGatewayBodyFile(path)
}
return svc
}
......@@ -3740,9 +3749,28 @@ func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequ
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
func normalizeSystemParam(system any) any {
raw, ok := system.(json.RawMessage)
if !ok {
return system
}
if len(raw) == 0 {
return nil
}
var parsed any
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil
}
return parsed
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
system = normalizeSystemParam(system)
switch v := system.(type) {
case string:
return hasClaudeCodePrefix(v)
......@@ -3771,6 +3799,7 @@ func hasClaudeCodePrefix(text string) bool {
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
system = normalizeSystemParam(system)
claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
if err != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err)
......@@ -4048,8 +4077,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqStream := parsed.Stream
originalModel := reqModel
// === DEBUG: 打印客户端原始请求 body ===
debugLogRequestBody("CLIENT_ORIGINAL", body)
// === DEBUG: 打印客户端原始请求(headers + body 摘要)===
if c != nil {
s.debugLogGatewaySnapshot("CLIENT_ORIGINAL", c.Request.Header, body, map[string]string{
"account": fmt.Sprintf("%d(%s)", account.ID, account.Name),
"account_type": string(account.Type),
"model": reqModel,
"stream": strconv.FormatBool(reqStream),
})
}
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
......@@ -4066,12 +4102,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = metadataUserID
}
}
}
}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
......@@ -4116,9 +4156,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL()
}
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
// 调试日志:记录即将转发的账号信息
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
......@@ -4138,7 +4181,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......@@ -4171,7 +4214,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
if s.shouldRectifySignatureError(ctx, account, respBody) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
......@@ -4216,7 +4259,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseRetryCtx()
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
if retryErr == nil {
if retryResp.StatusCode < 400 {
logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID)
......@@ -4226,7 +4269,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isSignatureErrorPattern(ctx, account, retryRespBody) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
......@@ -4251,7 +4294,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseRetryCtx2()
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, tlsProfile)
if retryErr2 == nil {
resp = retryResp2
break
......@@ -4322,7 +4365,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseBudgetRetryCtx()
if buildErr == nil {
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
if retryErr == nil {
resp = budgetRetryResp
break
......@@ -4628,7 +4671,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
return nil, err
}
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......@@ -4840,8 +4883,9 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
if !allowedHeaders[lowerKey] {
continue
}
wireKey := resolveWireCasing(key)
for _, v := range values {
req.Header.Add(key, v)
addHeaderRaw(req.Header, wireKey, v)
}
}
}
......@@ -4851,13 +4895,13 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
req.Header.Del("x-api-key")
req.Header.Del("x-goog-api-key")
req.Header.Del("cookie")
req.Header.Set("x-api-key", token)
setHeaderRaw(req.Header, "x-api-key", token)
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
if getHeaderRaw(req.Header, "content-type") == "" {
setHeaderRaw(req.Header, "content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
if getHeaderRaw(req.Header, "anthropic-version") == "" {
setHeaderRaw(req.Header, "anthropic-version", "2023-06-01")
}
return req, nil
......@@ -5346,7 +5390,7 @@ func (s *GatewayService) executeBedrockUpstream(
return nil, err
}
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false)
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, nil)
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......@@ -5591,8 +5635,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
clientHeaders = c.Request.Header
}
// OAuth账号:应用统一指纹
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
var fingerprint *Fingerprint
enableFP, enableMPT := true, false
if s.settingService != nil {
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
......@@ -5600,10 +5648,14 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
} else {
if enableFP {
fingerprint = fp
}
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
// 当 metadata 透传开启时跳过重写
if !enableMPT {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
......@@ -5612,28 +5664,27 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
// === DEBUG: 打印转发给上游的 body(metadata 已重写) ===
debugLogRequestBody("UPSTREAM_FORWARD", body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
// 设置认证头(保持原始大小写)
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
setHeaderRaw(req.Header, "authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
setHeaderRaw(req.Header, "x-api-key", token)
}
// 白名单透传headers
// 白名单透传headers(恢复真实 wire casing)
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
req.Header.Add(key, v)
addHeaderRaw(req.Header, wireKey, v)
}
}
}
......@@ -5643,15 +5694,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
s.identityService.ApplyFingerprint(req, fingerprint)
}
// 确保必要的headers存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
// 确保必要的headers存在(保持原始大小写)
if getHeaderRaw(req.Header, "content-type") == "" {
setHeaderRaw(req.Header, "content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
if getHeaderRaw(req.Header, "anthropic-version") == "" {
setHeaderRaw(req.Header, "anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, reqStream)
applyClaudeOAuthHeaderDefaults(req)
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
......@@ -5667,31 +5718,41 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := req.Header.Get("anthropic-beta")
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
}
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
setHeaderRaw(req.Header, "anthropic-beta", beta)
}
}
}
}
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(),
"token_type": tokenType,
"mimic_claude_code": strconv.FormatBool(mimicClaudeCode),
"fingerprint_applied": strconv.FormatBool(fingerprint != nil),
"enable_fp": strconv.FormatBool(enableFP),
"enable_mpt": strconv.FormatBool(enableMPT),
})
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if c != nil && tokenType == "oauth" {
......@@ -5771,23 +5832,20 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader
}
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
if req == nil {
return
}
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "application/json")
if getHeaderRaw(req.Header, "Accept") == "" {
setHeaderRaw(req.Header, "Accept", "application/json")
}
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
if req.Header.Get(key) == "" {
req.Header.Set(key, value)
}
if getHeaderRaw(req.Header, key) == "" {
setHeaderRaw(req.Header, resolveWireCasing(key), value)
}
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
......@@ -6083,18 +6141,19 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults(req, isStream)
applyClaudeOAuthHeaderDefaults(req)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
// 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App")
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
req.Header.Set(key, value)
setHeaderRaw(req.Header, resolveWireCasing(key), value)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req.Header.Set("accept", "application/json")
setHeaderRaw(req.Header, "Accept", "application/json")
if isStream {
req.Header.Set("x-stainless-helper-method", "stream")
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
}
}
......@@ -6112,6 +6171,59 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// shouldRectifySignatureError 统一判断是否应触发签名整流(strip thinking blocks 并重试)。
// 根据账号类型检查对应的开关和匹配模式。
func (s *GatewayService) shouldRectifySignatureError(ctx context.Context, account *Account, respBody []byte) bool {
if account.Type == AccountTypeAPIKey {
// API Key 账号:独立开关,一次读取配置
settings, err := s.settingService.GetRectifierSettings(ctx)
if err != nil || !settings.Enabled || !settings.APIKeySignatureEnabled {
return false
}
// 先检查内置模式(同 OAuth),再检查自定义关键词
if s.isThinkingBlockSignatureError(respBody) {
return true
}
return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns)
}
// OAuth/SetupToken/Upstream/Bedrock 等:保持原有行为(内置模式 + 原开关)
return s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx)
}
// isSignatureErrorPattern 仅做模式匹配,不检查开关。
// 用于已进入重试流程后的二阶段检测(此时开关已在首次调用时验证过)。
func (s *GatewayService) isSignatureErrorPattern(ctx context.Context, account *Account, respBody []byte) bool {
if s.isThinkingBlockSignatureError(respBody) {
return true
}
if account.Type == AccountTypeAPIKey {
settings, err := s.settingService.GetRectifierSettings(ctx)
if err != nil {
return false
}
return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns)
}
return false
}
// matchSignaturePatterns 检查响应体是否匹配自定义关键词列表(不区分大小写)。
func matchSignaturePatterns(respBody []byte, patterns []string) bool {
if len(patterns) == 0 {
return false
}
bodyLower := strings.ToLower(string(respBody))
for _, p := range patterns {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if strings.Contains(bodyLower, strings.ToLower(p)) {
return true
}
}
return false
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
......@@ -7958,7 +8070,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
......@@ -7980,13 +8092,13 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
if resp.StatusCode == 400 && s.shouldRectifySignatureError(ctx, account, respBody) {
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if retryErr == nil {
resp = retryResp
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
......@@ -8075,7 +8187,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......@@ -8197,8 +8309,9 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
if !allowedHeaders[lowerKey] {
continue
}
wireKey := resolveWireCasing(key)
for _, v := range values {
req.Header.Add(key, v)
addHeaderRaw(req.Header, wireKey, v)
}
}
}
......@@ -8239,11 +8352,18 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
clientHeaders = c.Request.Header
}
// OAuth 账号:应用统一指纹和重写 userID
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT := true, false
if s.settingService != nil {
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err == nil {
ctFingerprint = fp
if !ctEnableMPT {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
......@@ -8252,46 +8372,45 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
// 设置认证头(保持原始大小写)
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
setHeaderRaw(req.Header, "authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
setHeaderRaw(req.Header, "x-api-key", token)
}
// 白名单透传 headers
// 白名单透传 headers(恢复真实 wire casing)
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
wireKey := resolveWireCasing(key)
for _, v := range values {
req.Header.Add(key, v)
addHeaderRaw(req.Header, wireKey, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
// OAuth 账号:应用指纹到请求头(受设置开关控制)
if ctEnableFP && ctFingerprint != nil {
s.identityService.ApplyFingerprint(req, ctFingerprint)
}
// 确保必要的 headers 存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
// 确保必要的 headers 存在(保持原始大小写)
if getHeaderRaw(req.Header, "content-type") == "" {
setHeaderRaw(req.Header, "content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
if getHeaderRaw(req.Header, "anthropic-version") == "" {
setHeaderRaw(req.Header, "anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, false)
applyClaudeOAuthHeaderDefaults(req)
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
......@@ -8302,30 +8421,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := req.Header.Get("anthropic-beta")
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
if clientBetaHeader == "" {
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader)
} else {
beta := s.getBetaHeader(modelID, clientBetaHeader)
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
}
}
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
setHeaderRaw(req.Header, "anthropic-beta", beta)
}
}
}
......@@ -8496,42 +8615,94 @@ func reconcileCachedTokens(usage map[string]any) bool {
return true
}
func debugGatewayBodyLoggingEnabled() bool {
raw := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv))
if raw == "" {
return false
const debugGatewayBodyDefaultFilename = "gateway_debug.log"
// initDebugGatewayBodyFile 初始化网关调试日志文件。
//
// - "1"/"true" 等布尔值 → 当前目录下 gateway_debug.log
// - 已有目录路径 → 该目录下 gateway_debug.log
// - 其他 → 视为完整文件路径
func (s *GatewayService) initDebugGatewayBodyFile(path string) {
if parseDebugEnvBool(path) {
path = debugGatewayBodyDefaultFilename
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
default:
return false
// 如果 path 指向一个已存在的目录,自动追加默认文件名
if info, err := os.Stat(path); err == nil && info.IsDir() {
path = filepath.Join(path, debugGatewayBodyDefaultFilename)
}
// 确保父目录存在
if dir := filepath.Dir(path); dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
slog.Error("failed to create gateway debug log directory", "dir", dir, "error", err)
return
}
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
slog.Error("failed to open gateway debug log file", "path", path, "error", err)
return
}
s.debugGatewayBodyFile.Store(f)
slog.Info("gateway debug logging enabled", "path", path)
}
// debugLogRequestBody 打印请求 body 用于调试 metadata.user_id 重写。
// 默认关闭,仅在设置环境变量时启用:
// debugLogGatewaySnapshot 将网关请求的完整快照(headers + body)写入独立的调试日志文件,
// 用于对比客户端原始请求和上游转发请求。
//
// 启用方式(环境变量):
//
// SUB2API_DEBUG_GATEWAY_BODY=1
func debugLogRequestBody(tag string, body []byte) {
if !debugGatewayBodyLoggingEnabled() {
// SUB2API_DEBUG_GATEWAY_BODY=1 # 写入 gateway_debug.log
// SUB2API_DEBUG_GATEWAY_BODY=/tmp/gateway_debug.log # 写入指定路径
//
// tag: "CLIENT_ORIGINAL" 或 "UPSTREAM_FORWARD"
func (s *GatewayService) debugLogGatewaySnapshot(tag string, headers http.Header, body []byte, extra map[string]string) {
f := s.debugGatewayBodyFile.Load()
if f == nil {
return
}
if len(body) == 0 {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body is empty", tag)
return
var buf strings.Builder
ts := time.Now().Format("2006-01-02 15:04:05.000")
fmt.Fprintf(&buf, "\n========== [%s] %s ==========\n", ts, tag)
// 1. context
if len(extra) > 0 {
fmt.Fprint(&buf, "--- context ---\n")
extraKeys := make([]string, 0, len(extra))
for k := range extra {
extraKeys = append(extraKeys, k)
}
sort.Strings(extraKeys)
for _, k := range extraKeys {
fmt.Fprintf(&buf, " %s: %s\n", k, extra[k])
}
}
// 提取 metadata 字段完整打印
metadataResult := gjson.GetBytes(body, "metadata")
if metadataResult.Exists() {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata = %s", tag, metadataResult.Raw)
// 2. headers(按真实 Claude CLI wire 顺序排列,便于与抓包对比;auth 脱敏)
fmt.Fprint(&buf, "--- headers ---\n")
for _, k := range sortHeadersByWireOrder(headers) {
for _, v := range headers[k] {
fmt.Fprintf(&buf, " %s: %s\n", k, safeHeaderValueForLog(k, v))
}
}
// 3. body(完整输出,格式化 JSON 便于 diff)
fmt.Fprint(&buf, "--- body ---\n")
if len(body) == 0 {
fmt.Fprint(&buf, " (empty)\n")
} else {
var pretty bytes.Buffer
if json.Indent(&pretty, body, " ", " ") == nil {
fmt.Fprintf(&buf, " %s\n", pretty.Bytes())
} else {
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] metadata field not found", tag)
// JSON 格式化失败时原样输出
fmt.Fprintf(&buf, " %s\n", body)
}
}
// 全量打印 body
logger.LegacyPrintf("service.gateway", "[DEBUG_%s] body (%d bytes) = %s", tag, len(body), string(body))
// 写入文件(调试用,并发写入可能交错但不影响可读性)
_, _ = f.WriteString(buf.String())
}
......@@ -12,6 +12,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -36,7 +37,7 @@ func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, ac
return &resp, nil
}
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
......
package service
import (
"net/http"
"strings"
)
// headerWireCasing 定义每个白名单 header 在真实 Claude CLI 抓包中的准确大小写。
// Go 的 HTTP server 解析请求时会将所有 header key 转为 Canonical 形式(如 x-app → X-App),
// 此 map 用于在转发时恢复到真实的 wire format。
//
// 来源:对真实 Claude CLI (claude-cli/2.1.81) 到 api.anthropic.com 的 HTTPS 流量抓包。
var headerWireCasing = map[string]string{
// Title case
"accept": "Accept",
"user-agent": "User-Agent",
// X-Stainless-* 保持 SDK 原始大小写
"x-stainless-retry-count": "X-Stainless-Retry-Count",
"x-stainless-timeout": "X-Stainless-Timeout",
"x-stainless-lang": "X-Stainless-Lang",
"x-stainless-package-version": "X-Stainless-Package-Version",
"x-stainless-os": "X-Stainless-OS",
"x-stainless-arch": "X-Stainless-Arch",
"x-stainless-runtime": "X-Stainless-Runtime",
"x-stainless-runtime-version": "X-Stainless-Runtime-Version",
"x-stainless-helper-method": "x-stainless-helper-method",
// Anthropic SDK 自身设置的 header,全小写
"anthropic-dangerous-direct-browser-access": "anthropic-dangerous-direct-browser-access",
"anthropic-version": "anthropic-version",
"anthropic-beta": "anthropic-beta",
"x-app": "x-app",
"content-type": "content-type",
"accept-language": "accept-language",
"sec-fetch-mode": "sec-fetch-mode",
"accept-encoding": "accept-encoding",
"authorization": "authorization",
}
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
// 用于 debug log 按此顺序输出,便于与抓包结果直接对比。
var headerWireOrder = []string{
"Accept",
"X-Stainless-Retry-Count",
"X-Stainless-Timeout",
"X-Stainless-Lang",
"X-Stainless-Package-Version",
"X-Stainless-OS",
"X-Stainless-Arch",
"X-Stainless-Runtime",
"X-Stainless-Runtime-Version",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"authorization",
"x-app",
"User-Agent",
"content-type",
"anthropic-beta",
"accept-language",
"sec-fetch-mode",
"accept-encoding",
"x-stainless-helper-method",
}
// headerWireOrderSet 用于快速判断某个 key 是否在 headerWireOrder 中(按 lowercase 匹配)。
var headerWireOrderSet map[string]struct{}
func init() {
headerWireOrderSet = make(map[string]struct{}, len(headerWireOrder))
for _, k := range headerWireOrder {
headerWireOrderSet[strings.ToLower(k)] = struct{}{}
}
}
// resolveWireCasing 将 Go canonical key(如 X-Stainless-Os)映射为真实 wire casing(如 X-Stainless-OS)。
// 如果 map 中没有对应条目,返回原始 key 不变。
func resolveWireCasing(key string) string {
if wk, ok := headerWireCasing[strings.ToLower(key)]; ok {
return wk
}
return key
}
// setHeaderRaw sets a header bypassing Go's canonical-case normalization.
// The key is stored exactly as provided, preserving original casing.
//
// It first removes any existing value under the canonical key, the wire casing key,
// and the exact raw key, preventing duplicates from any source.
func setHeaderRaw(h http.Header, key, value string) {
h.Del(key) // remove canonical form (e.g. "Anthropic-Beta")
if wk := resolveWireCasing(key); wk != key {
delete(h, wk) // remove wire casing form if different
}
delete(h, key) // remove exact raw key if it differs from canonical
h[key] = []string{value}
}
// addHeaderRaw appends a header value bypassing Go's canonical-case normalization.
func addHeaderRaw(h http.Header, key, value string) {
h[key] = append(h[key], value)
}
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
// between Go canonical keys, wire casing keys, and raw keys:
// 1. exact key as provided
// 2. wire casing form (from headerWireCasing)
// 3. Go canonical form (via http.Header.Get)
func getHeaderRaw(h http.Header, key string) string {
// 1. exact key
if vals := h[key]; len(vals) > 0 {
return vals[0]
}
// 2. wire casing (e.g. looking up "Anthropic-Dangerous-Direct-Browser-Access" finds "anthropic-dangerous-direct-browser-access")
if wk := resolveWireCasing(key); wk != key {
if vals := h[wk]; len(vals) > 0 {
return vals[0]
}
}
// 3. canonical fallback
return h.Get(key)
}
// sortHeadersByWireOrder 按照真实 Claude CLI 的 header 顺序返回排序后的 key 列表。
// 在 headerWireOrder 中定义的 key 按其顺序排列,未定义的 key 追加到末尾。
func sortHeadersByWireOrder(h http.Header) []string {
// 构建 lowercase -> actual map key 的映射
present := make(map[string]string, len(h))
for k := range h {
present[strings.ToLower(k)] = k
}
result := make([]string, 0, len(h))
seen := make(map[string]struct{}, len(h))
// 先按 wire order 输出
for _, wk := range headerWireOrder {
lk := strings.ToLower(wk)
if actual, ok := present[lk]; ok {
if _, dup := seen[lk]; !dup {
result = append(result, actual)
seen[lk] = struct{}{}
}
}
}
// 再追加不在 wire order 中的 header
for k := range h {
lk := strings.ToLower(k)
if _, ok := seen[lk]; !ok {
result = append(result, k)
seen[lk] = struct{}{}
}
}
return result
}
package service
import "net/http"
import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
// HTTPUpstream 上游 HTTP 请求接口
// 用于向上游 API(Claude、OpenAI、Gemini 等)发送请求
// 这是一个通用接口,可用于任何基于 HTTP 的上游服务
//
// 设计说明:
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type HTTPUpstream interface {
// Do 执行 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离(隔离策略为 account 或 account_proxy 时生效)
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
// Do 执行 HTTP 请求(不启用 TLS 指纹)
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
// profile 参数:
// - nil: 不启用 TLS 指纹,行为与 Do 方法相同
// - non-nil: 使用指定的 Profile 进行 TLS 指纹伪装
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
// Profile 由调用方通过 TLSFingerprintProfileService 解析后传入,
// 支持按账号绑定的数据库 profile 或内置默认 profile。
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error)
}
......@@ -174,6 +174,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
}
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
// 使用 setHeaderRaw 保持原始大小写(如 X-Stainless-OS 而非 X-Stainless-Os)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
if fp == nil {
return
......@@ -181,27 +182,27 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
// 设置user-agent
if fp.UserAgent != "" {
req.Header.Set("user-agent", fp.UserAgent)
setHeaderRaw(req.Header, "User-Agent", fp.UserAgent)
}
// 设置x-stainless-*头
// 设置x-stainless-*头(保持与 claude.DefaultHeaders 一致的大小写)
if fp.StainlessLang != "" {
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
setHeaderRaw(req.Header, "X-Stainless-Lang", fp.StainlessLang)
}
if fp.StainlessPackageVersion != "" {
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
setHeaderRaw(req.Header, "X-Stainless-Package-Version", fp.StainlessPackageVersion)
}
if fp.StainlessOS != "" {
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
setHeaderRaw(req.Header, "X-Stainless-OS", fp.StainlessOS)
}
if fp.StainlessArch != "" {
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
setHeaderRaw(req.Header, "X-Stainless-Arch", fp.StainlessArch)
}
if fp.StainlessRuntime != "" {
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
setHeaderRaw(req.Header, "X-Stainless-Runtime", fp.StainlessRuntime)
}
if fp.StainlessRuntimeVersion != "" {
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
setHeaderRaw(req.Header, "X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment