Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
292f25f9
Commit
292f25f9
authored
Jan 20, 2026
by
yangjianbo
Browse files
Merge branch 'main' of
https://github.com/mt21625457/aicodex2api
parents
c92e3777
fbb57294
Changes
60
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/main.go
View file @
292f25f9
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"errors"
"errors"
"flag"
"flag"
"log"
"log"
"log/slog"
"net/http"
"net/http"
"os"
"os"
"os/signal"
"os/signal"
...
@@ -44,7 +45,25 @@ func init() {
...
@@ -44,7 +45,25 @@ func init() {
}
}
}
}
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func
initLogger
()
{
var
level
slog
.
Level
if
gin
.
Mode
()
==
gin
.
ReleaseMode
{
level
=
slog
.
LevelInfo
}
else
{
level
=
slog
.
LevelDebug
}
handler
:=
slog
.
NewTextHandler
(
os
.
Stderr
,
&
slog
.
HandlerOptions
{
Level
:
level
,
})
slog
.
SetDefault
(
slog
.
New
(
handler
))
}
func
main
()
{
func
main
()
{
// Initialize slog logger based on gin mode
initLogger
()
// Parse command line flags
// Parse command line flags
setupMode
:=
flag
.
Bool
(
"setup"
,
false
,
"Run setup wizard in CLI mode"
)
setupMode
:=
flag
.
Bool
(
"setup"
,
false
,
"Run setup wizard in CLI mode"
)
showVersion
:=
flag
.
Bool
(
"version"
,
false
,
"Show version information"
)
showVersion
:=
flag
.
Bool
(
"version"
,
false
,
"Show version information"
)
...
...
backend/internal/config/config.go
View file @
292f25f9
...
@@ -258,8 +258,43 @@ type GatewayConfig struct {
...
@@ -258,8 +258,43 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400
bool
`mapstructure:"failover_on_400"`
FailoverOn400
bool
`mapstructure:"failover_on_400"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches
int
`mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
MaxAccountSwitchesGemini
int
`mapstructure:"max_account_switches_gemini"`
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
AntigravityFallbackCooldownMinutes
int
`mapstructure:"antigravity_fallback_cooldown_minutes"`
// Scheduling: 账号调度相关配置
// Scheduling: 账号调度相关配置
Scheduling
GatewaySchedulingConfig
`mapstructure:"scheduling"`
Scheduling
GatewaySchedulingConfig
`mapstructure:"scheduling"`
// TLSFingerprint: TLS指纹伪装配置
TLSFingerprint
TLSFingerprintConfig
`mapstructure:"tls_fingerprint"`
}
// TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type
TLSFingerprintConfig
struct
{
// Enabled: 是否全局启用TLS指纹功能
Enabled
bool
`mapstructure:"enabled"`
// Profiles: 预定义的TLS指纹配置模板
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
Profiles
map
[
string
]
TLSProfileConfig
`mapstructure:"profiles"`
}
// TLSProfileConfig 单个TLS指纹模板的配置
type
TLSProfileConfig
struct
{
// Name: 模板显示名称
Name
string
`mapstructure:"name"`
// EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
EnableGREASE
bool
`mapstructure:"enable_grease"`
// CipherSuites: TLS加密套件列表(空则使用内置默认值)
CipherSuites
[]
uint16
`mapstructure:"cipher_suites"`
// Curves: 椭圆曲线列表(空则使用内置默认值)
Curves
[]
uint16
`mapstructure:"curves"`
// PointFormats: 点格式列表(空则使用内置默认值)
PointFormats
[]
uint8
`mapstructure:"point_formats"`
}
}
// GatewaySchedulingConfig accounts scheduling configuration.
// GatewaySchedulingConfig accounts scheduling configuration.
...
@@ -272,6 +307,9 @@ type GatewaySchedulingConfig struct {
...
@@ -272,6 +307,9 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout
time
.
Duration
`mapstructure:"fallback_wait_timeout"`
FallbackWaitTimeout
time
.
Duration
`mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting
int
`mapstructure:"fallback_max_waiting"`
FallbackMaxWaiting
int
`mapstructure:"fallback_max_waiting"`
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
FallbackSelectionMode
string
`mapstructure:"fallback_selection_mode"`
// 负载计算
// 负载计算
LoadBatchEnabled
bool
`mapstructure:"load_batch_enabled"`
LoadBatchEnabled
bool
`mapstructure:"load_batch_enabled"`
...
@@ -781,6 +819,9 @@ func setDefaults() {
...
@@ -781,6 +819,9 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.log_upstream_error_body_max_bytes"
,
2048
)
viper
.
SetDefault
(
"gateway.log_upstream_error_body_max_bytes"
,
2048
)
viper
.
SetDefault
(
"gateway.inject_beta_for_apikey"
,
false
)
viper
.
SetDefault
(
"gateway.inject_beta_for_apikey"
,
false
)
viper
.
SetDefault
(
"gateway.failover_on_400"
,
false
)
viper
.
SetDefault
(
"gateway.failover_on_400"
,
false
)
viper
.
SetDefault
(
"gateway.max_account_switches"
,
10
)
viper
.
SetDefault
(
"gateway.max_account_switches_gemini"
,
3
)
viper
.
SetDefault
(
"gateway.antigravity_fallback_cooldown_minutes"
,
1
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
...
@@ -793,11 +834,12 @@ func setDefaults() {
...
@@ -793,11 +834,12 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
30
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
30
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.stream_data_interval_timeout"
,
180
)
viper
.
SetDefault
(
"gateway.stream_data_interval_timeout"
,
180
)
viper
.
SetDefault
(
"gateway.stream_keepalive_interval"
,
10
)
viper
.
SetDefault
(
"gateway.stream_keepalive_interval"
,
10
)
viper
.
SetDefault
(
"gateway.max_line_size"
,
1
0
*
1024
*
1024
)
viper
.
SetDefault
(
"gateway.max_line_size"
,
4
0
*
1024
*
1024
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
120
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
120
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_max_waiting"
,
100
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_max_waiting"
,
100
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_selection_mode"
,
"last_used"
)
viper
.
SetDefault
(
"gateway.scheduling.load_batch_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.load_batch_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.slot_cleanup_interval"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.slot_cleanup_interval"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.db_fallback_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.db_fallback_enabled"
,
true
)
...
@@ -809,6 +851,8 @@ func setDefaults() {
...
@@ -809,6 +851,8 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.scheduling.outbox_lag_rebuild_failures"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.outbox_lag_rebuild_failures"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.outbox_backlog_rebuild_rows"
,
10000
)
viper
.
SetDefault
(
"gateway.scheduling.outbox_backlog_rebuild_rows"
,
10000
)
viper
.
SetDefault
(
"gateway.scheduling.full_rebuild_interval_seconds"
,
300
)
viper
.
SetDefault
(
"gateway.scheduling.full_rebuild_interval_seconds"
,
300
)
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
viper
.
SetDefault
(
"gateway.tls_fingerprint.enabled"
,
true
)
viper
.
SetDefault
(
"concurrency.ping_interval"
,
10
)
viper
.
SetDefault
(
"concurrency.ping_interval"
,
10
)
// TokenRefresh
// TokenRefresh
...
...
backend/internal/handler/admin/account_handler.go
View file @
292f25f9
...
@@ -173,6 +173,7 @@ func (h *AccountHandler) List(c *gin.Context) {
...
@@ -173,6 +173,7 @@ func (h *AccountHandler) List(c *gin.Context) {
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
sessionIdleTimeouts
:=
make
(
map
[
int64
]
time
.
Duration
)
// 各账号的会话空闲超时配置
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
acc
:=
&
accounts
[
i
]
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
...
@@ -181,6 +182,7 @@ func (h *AccountHandler) List(c *gin.Context) {
...
@@ -181,6 +182,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
if
acc
.
GetMaxSessions
()
>
0
{
if
acc
.
GetMaxSessions
()
>
0
{
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
sessionIdleTimeouts
[
acc
.
ID
]
=
time
.
Duration
(
acc
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
}
}
}
}
}
}
...
@@ -189,9 +191,9 @@ func (h *AccountHandler) List(c *gin.Context) {
...
@@ -189,9 +191,9 @@ func (h *AccountHandler) List(c *gin.Context) {
var
windowCosts
map
[
int64
]
float64
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
var
activeSessions
map
[
int64
]
int
// 获取活跃会话数(批量查询)
// 获取活跃会话数(批量查询
,传入各账号的 idleTimeout 配置
)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
)
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
,
sessionIdleTimeouts
)
if
activeSessions
==
nil
{
if
activeSessions
==
nil
{
activeSessions
=
make
(
map
[
int64
]
int
)
activeSessions
=
make
(
map
[
int64
]
int
)
}
}
...
@@ -211,12 +213,8 @@ func (h *AccountHandler) List(c *gin.Context) {
...
@@ -211,12 +213,8 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
accCopy
:=
acc
// 闭包捕获
accCopy
:=
acc
// 闭包捕获
g
.
Go
(
func
()
error
{
g
.
Go
(
func
()
error
{
var
startTime
time
.
Time
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
if
accCopy
.
SessionWindowStart
!=
nil
{
startTime
:=
accCopy
.
GetCurrentWindowStartTime
()
startTime
=
*
accCopy
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
if
err
==
nil
&&
stats
!=
nil
{
if
err
==
nil
&&
stats
!=
nil
{
mu
.
Lock
()
mu
.
Lock
()
...
@@ -545,6 +543,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
...
@@ -545,6 +543,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials
[
k
]
=
v
newCredentials
[
k
]
=
v
}
}
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
if
tokenInfo
.
ProjectIDMissing
{
// 先更新凭证
_
,
updateErr
:=
h
.
adminService
.
UpdateAccount
(
c
.
Request
.
Context
(),
accountID
,
&
service
.
UpdateAccountInput
{
Credentials
:
newCredentials
,
})
if
updateErr
!=
nil
{
response
.
InternalError
(
c
,
"Failed to update credentials: "
+
updateErr
.
Error
())
return
}
// 标记账户为 error
if
setErr
:=
h
.
adminService
.
SetAccountError
(
c
.
Request
.
Context
(),
accountID
,
"missing_project_id: 账户缺少project id,可能无法使用Antigravity"
);
setErr
!=
nil
{
response
.
InternalError
(
c
,
"Failed to set account error: "
+
setErr
.
Error
())
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Token refreshed but project_id is missing, account marked as error"
,
"warning"
:
"missing_project_id"
,
})
return
}
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
if
account
.
Status
==
service
.
StatusError
&&
strings
.
Contains
(
account
.
ErrorMessage
,
"missing_project_id:"
)
{
if
_
,
clearErr
:=
h
.
adminService
.
ClearAccountError
(
c
.
Request
.
Context
(),
accountID
);
clearErr
!=
nil
{
response
.
InternalError
(
c
,
"Failed to clear account error: "
+
clearErr
.
Error
())
return
}
}
}
else
{
}
else
{
// Use Anthropic/Claude OAuth service to refresh token
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo
,
err
:=
h
.
oauthService
.
RefreshAccountToken
(
c
.
Request
.
Context
(),
account
)
tokenInfo
,
err
:=
h
.
oauthService
.
RefreshAccountToken
(
c
.
Request
.
Context
(),
account
)
...
...
backend/internal/handler/admin/admin_service_stub_test.go
View file @
292f25f9
...
@@ -200,6 +200,10 @@ func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*se
...
@@ -200,6 +200,10 @@ func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*se
return
&
account
,
nil
return
&
account
,
nil
}
}
func
(
s
*
stubAdminService
)
SetAccountError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
s
*
stubAdminService
)
SetAccountSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
(
*
service
.
Account
,
error
)
{
func
(
s
*
stubAdminService
)
SetAccountSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
(
*
service
.
Account
,
error
)
{
account
:=
service
.
Account
{
ID
:
id
,
Name
:
"account"
,
Status
:
service
.
StatusActive
,
Schedulable
:
schedulable
}
account
:=
service
.
Account
{
ID
:
id
,
Name
:
"account"
,
Status
:
service
.
StatusActive
,
Schedulable
:
schedulable
}
return
&
account
,
nil
return
&
account
,
nil
...
...
backend/internal/handler/dto/mappers.go
View file @
292f25f9
...
@@ -161,6 +161,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
...
@@ -161,6 +161,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
}
// TLS指纹伪装开关
if
a
.
IsTLSFingerprintEnabled
()
{
enabled
:=
true
out
.
EnableTLSFingerprint
=
&
enabled
}
// 会话ID伪装开关
if
a
.
IsSessionIDMaskingEnabled
()
{
enabled
:=
true
out
.
EnableSessionIDMasking
=
&
enabled
}
}
}
return
out
return
out
...
...
backend/internal/handler/dto/types.go
View file @
292f25f9
...
@@ -112,6 +112,15 @@ type Account struct {
...
@@ -112,6 +112,15 @@ type Account struct {
MaxSessions
*
int
`json:"max_sessions,omitempty"`
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint
*
bool
`json:"enable_tls_fingerprint,omitempty"`
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking
*
bool
`json:"session_id_masking_enabled,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
backend/internal/handler/gateway_handler.go
View file @
292f25f9
...
@@ -31,6 +31,8 @@ type GatewayHandler struct {
...
@@ -31,6 +31,8 @@ type GatewayHandler struct {
userService
*
service
.
UserService
userService
*
service
.
UserService
billingCacheService
*
service
.
BillingCacheService
billingCacheService
*
service
.
BillingCacheService
concurrencyHelper
*
ConcurrencyHelper
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
maxAccountSwitchesGemini
int
}
}
// NewGatewayHandler creates a new GatewayHandler
// NewGatewayHandler creates a new GatewayHandler
...
@@ -44,8 +46,16 @@ func NewGatewayHandler(
...
@@ -44,8 +46,16 @@ func NewGatewayHandler(
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
GatewayHandler
{
)
*
GatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
10
maxAccountSwitchesGemini
:=
3
if
cfg
!=
nil
{
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
if
cfg
.
Gateway
.
MaxAccountSwitches
>
0
{
maxAccountSwitches
=
cfg
.
Gateway
.
MaxAccountSwitches
}
if
cfg
.
Gateway
.
MaxAccountSwitchesGemini
>
0
{
maxAccountSwitchesGemini
=
cfg
.
Gateway
.
MaxAccountSwitchesGemini
}
}
}
return
&
GatewayHandler
{
return
&
GatewayHandler
{
gatewayService
:
gatewayService
,
gatewayService
:
gatewayService
,
...
@@ -54,6 +64,8 @@ func NewGatewayHandler(
...
@@ -54,6 +64,8 @@ func NewGatewayHandler(
userService
:
userService
,
userService
:
userService
,
billingCacheService
:
billingCacheService
,
billingCacheService
:
billingCacheService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
maxAccountSwitchesGemini
:
maxAccountSwitchesGemini
,
}
}
}
}
...
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
if
platform
==
service
.
PlatformGemini
{
if
platform
==
service
.
PlatformGemini
{
const
maxAccountSwitches
=
3
maxAccountSwitches
:=
h
.
maxAccountSwitches
Gemini
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
...
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
}
}
const
maxAccountSwitches
=
10
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
292f25f9
...
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
sessionKey
=
"gemini:"
+
sessionHash
}
}
const
maxAccountSwitches
=
3
maxAccountSwitches
:=
h
.
maxAccountSwitches
Gemini
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
292f25f9
...
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
...
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
gatewayService
*
service
.
OpenAIGatewayService
gatewayService
*
service
.
OpenAIGatewayService
billingCacheService
*
service
.
BillingCacheService
billingCacheService
*
service
.
BillingCacheService
concurrencyHelper
*
ConcurrencyHelper
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
}
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
...
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
...
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
OpenAIGatewayHandler
{
)
*
OpenAIGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
3
if
cfg
!=
nil
{
if
cfg
!=
nil
{
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
pingInterval
=
time
.
Duration
(
cfg
.
Concurrency
.
PingInterval
)
*
time
.
Second
if
cfg
.
Gateway
.
MaxAccountSwitches
>
0
{
maxAccountSwitches
=
cfg
.
Gateway
.
MaxAccountSwitches
}
}
}
return
&
OpenAIGatewayHandler
{
return
&
OpenAIGatewayHandler
{
gatewayService
:
gatewayService
,
gatewayService
:
gatewayService
,
billingCacheService
:
billingCacheService
,
billingCacheService
:
billingCacheService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
}
}
}
}
...
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key)
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
,
reqBody
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
c
,
reqBody
)
const
maxAccountSwitches
=
3
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
...
...
backend/internal/pkg/antigravity/client.go
View file @
292f25f9
...
@@ -16,15 +16,6 @@ import (
...
@@ -16,15 +16,6 @@ import (
"time"
"time"
)
)
// resolveHost 从 URL 解析 host
func
resolveHost
(
urlStr
string
)
string
{
parsed
,
err
:=
url
.
Parse
(
urlStr
)
if
err
!=
nil
{
return
""
}
return
parsed
.
Host
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
func
NewAPIRequestWithURL
(
ctx
context
.
Context
,
baseURL
,
action
,
accessToken
string
,
body
[]
byte
)
(
*
http
.
Request
,
error
)
{
func
NewAPIRequestWithURL
(
ctx
context
.
Context
,
baseURL
,
action
,
accessToken
string
,
body
[]
byte
)
(
*
http
.
Request
,
error
)
{
// 构建 URL,流式请求添加 ?alt=sse 参数
// 构建 URL,流式请求添加 ?alt=sse 参数
...
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
...
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return
nil
,
err
return
nil
,
err
}
}
// 基础 Headers
// 基础 Headers
(与 Antigravity-Manager 保持一致,只设置这 3 个)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
// Accept Header 根据请求类型设置
if
isStream
{
req
.
Header
.
Set
(
"Accept"
,
"text/event-stream"
)
}
else
{
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
}
// 显式设置 Host Header
if
host
:=
resolveHost
(
apiURL
);
host
!=
""
{
req
.
Host
=
host
}
return
req
,
nil
return
req
,
nil
}
}
...
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
...
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
}
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
//
仅连接错误和 HTTP 429
触发 URL 降级
//
与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx
触发 URL 降级
func
shouldFallbackToNextURL
(
err
error
,
statusCode
int
)
bool
{
func
shouldFallbackToNextURL
(
err
error
,
statusCode
int
)
bool
{
if
isConnectionError
(
err
)
{
if
isConnectionError
(
err
)
{
return
true
return
true
}
}
return
statusCode
==
http
.
StatusTooManyRequests
return
statusCode
==
http
.
StatusTooManyRequests
||
statusCode
==
http
.
StatusRequestTimeout
||
statusCode
==
http
.
StatusNotFound
||
statusCode
>=
500
}
}
// ExchangeCode 用 authorization code 交换 token
// ExchangeCode 用 authorization code 交换 token
...
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
...
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
}
}
// 获取可用的 URL 列表
// 固定顺序:prod -> daily
availableURLs
:=
DefaultURLAvailability
.
GetAvailableURLs
()
availableURLs
:=
BaseURLs
if
len
(
availableURLs
)
==
0
{
availableURLs
=
BaseURLs
// 所有 URL 都不可用时,重试所有
}
var
lastErr
error
var
lastErr
error
for
urlIdx
,
baseURL
:=
range
availableURLs
{
for
urlIdx
,
baseURL
:=
range
availableURLs
{
...
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
...
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if
err
!=
nil
{
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"loadCodeAssist 请求失败: %w"
,
err
)
lastErr
=
fmt
.
Errorf
(
"loadCodeAssist 请求失败: %w"
,
err
)
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
continue
}
}
...
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
...
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级
// 检查是否需要 URL 降级
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
continue
}
}
...
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
...
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var
rawResp
map
[
string
]
any
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
// 标记成功的 URL,下次优先使用
DefaultURLAvailability
.
MarkSuccess
(
baseURL
)
return
&
loadResp
,
rawResp
,
nil
return
&
loadResp
,
rawResp
,
nil
}
}
...
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
...
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
}
}
// 获取可用的 URL 列表
// 固定顺序:prod -> daily
availableURLs
:=
DefaultURLAvailability
.
GetAvailableURLs
()
availableURLs
:=
BaseURLs
if
len
(
availableURLs
)
==
0
{
availableURLs
=
BaseURLs
// 所有 URL 都不可用时,重试所有
}
var
lastErr
error
var
lastErr
error
for
urlIdx
,
baseURL
:=
range
availableURLs
{
for
urlIdx
,
baseURL
:=
range
availableURLs
{
...
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
...
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if
err
!=
nil
{
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"fetchAvailableModels 请求失败: %w"
,
err
)
lastErr
=
fmt
.
Errorf
(
"fetchAvailableModels 请求失败: %w"
,
err
)
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
continue
}
}
...
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
...
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级
// 检查是否需要 URL 降级
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
continue
}
}
...
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
...
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var
rawResp
map
[
string
]
any
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
// 标记成功的 URL,下次优先使用
DefaultURLAvailability
.
MarkSuccess
(
baseURL
)
return
&
modelsResp
,
rawResp
,
nil
return
&
modelsResp
,
rawResp
,
nil
}
}
...
...
backend/internal/pkg/antigravity/gemini_types.go
View file @
292f25f9
...
@@ -143,9 +143,10 @@ type GeminiResponse struct {
...
@@ -143,9 +143,10 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应
// GeminiCandidate Gemini 候选响应
type
GeminiCandidate
struct
{
type
GeminiCandidate
struct
{
Content
*
GeminiContent
`json:"content,omitempty"`
Content
*
GeminiContent
`json:"content,omitempty"`
FinishReason
string
`json:"finishReason,omitempty"`
FinishReason
string
`json:"finishReason,omitempty"`
Index
int
`json:"index,omitempty"`
Index
int
`json:"index,omitempty"`
GroundingMetadata
*
GeminiGroundingMetadata
`json:"groundingMetadata,omitempty"`
}
}
// GeminiUsageMetadata Gemini 用量元数据
// GeminiUsageMetadata Gemini 用量元数据
...
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
...
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
TotalTokenCount
int
`json:"totalTokenCount,omitempty"`
TotalTokenCount
int
`json:"totalTokenCount,omitempty"`
}
}
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
type
GeminiGroundingMetadata
struct
{
WebSearchQueries
[]
string
`json:"webSearchQueries,omitempty"`
GroundingChunks
[]
GeminiGroundingChunk
`json:"groundingChunks,omitempty"`
}
// GeminiGroundingChunk Gemini grounding chunk
type
GeminiGroundingChunk
struct
{
Web
*
GeminiGroundingWeb
`json:"web,omitempty"`
}
// GeminiGroundingWeb Gemini grounding web 信息
type
GeminiGroundingWeb
struct
{
Title
string
`json:"title,omitempty"`
URI
string
`json:"uri,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var
DefaultSafetySettings
=
[]
GeminiSafetySetting
{
var
DefaultSafetySettings
=
[]
GeminiSafetySetting
{
{
Category
:
"HARM_CATEGORY_HARASSMENT"
,
Threshold
:
"OFF"
},
{
Category
:
"HARM_CATEGORY_HARASSMENT"
,
Threshold
:
"OFF"
},
...
...
backend/internal/pkg/antigravity/oauth.go
View file @
292f25f9
...
@@ -32,8 +32,8 @@ const (
...
@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog "
+
"https://www.googleapis.com/auth/cclog "
+
"https://www.googleapis.com/auth/experimentsandconfigs"
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent(
模拟官方客户端
)
// User-Agent(
与 Antigravity-Manager 保持一致
)
UserAgent
=
"antigravity/1.1
04.0 darwin/arm
64"
UserAgent
=
"antigravity/1.1
1.9 windows/amd
64"
// Session 过期时间
// Session 过期时间
SessionTTL
=
30
*
time
.
Minute
SessionTTL
=
30
*
time
.
Minute
...
@@ -42,22 +42,21 @@ const (
...
@@ -42,22 +42,21 @@ const (
URLAvailabilityTTL
=
5
*
time
.
Minute
URLAvailabilityTTL
=
5
*
time
.
Minute
)
)
// BaseURLs 定义 Antigravity API 端点,按优先级排序
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
// fallback 顺序: sandbox → daily → prod
var
BaseURLs
=
[]
string
{
var
BaseURLs
=
[]
string
{
"https://daily-cloudcode-pa.sandbox.googleapis.com"
,
// sandbox
"https://cloudcode-pa.googleapis.com"
,
// prod (优先)
"https://daily-cloudcode-pa.googleapis.com"
,
// daily
"https://daily-cloudcode-pa.sandbox.googleapis.com"
,
// daily sandbox (备用)
"https://cloudcode-pa.googleapis.com"
,
// prod
}
}
// BaseURL 默认 URL(保持向后兼容)
// BaseURL 默认 URL(保持向后兼容)
var
BaseURL
=
BaseURLs
[
0
]
var
BaseURL
=
BaseURLs
[
0
]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复
和动态优先级
)
type
URLAvailability
struct
{
type
URLAvailability
struct
{
mu
sync
.
RWMutex
mu
sync
.
RWMutex
unavailable
map
[
string
]
time
.
Time
// URL -> 恢复时间
unavailable
map
[
string
]
time
.
Time
// URL -> 恢复时间
ttl
time
.
Duration
ttl
time
.
Duration
lastSuccess
string
// 最近成功请求的 URL,优先使用
}
}
// DefaultURLAvailability 全局 URL 可用性管理器
// DefaultURLAvailability 全局 URL 可用性管理器
...
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
...
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u
.
unavailable
[
url
]
=
time
.
Now
()
.
Add
(
u
.
ttl
)
u
.
unavailable
[
url
]
=
time
.
Now
()
.
Add
(
u
.
ttl
)
}
}
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
func
(
u
*
URLAvailability
)
MarkSuccess
(
url
string
)
{
u
.
mu
.
Lock
()
defer
u
.
mu
.
Unlock
()
u
.
lastSuccess
=
url
// 成功后清除该 URL 的不可用标记
delete
(
u
.
unavailable
,
url
)
}
// IsAvailable 检查 URL 是否可用
// IsAvailable 检查 URL 是否可用
func
(
u
*
URLAvailability
)
IsAvailable
(
url
string
)
bool
{
func
(
u
*
URLAvailability
)
IsAvailable
(
url
string
)
bool
{
u
.
mu
.
RLock
()
u
.
mu
.
RLock
()
...
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
...
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return
time
.
Now
()
.
After
(
expiry
)
return
time
.
Now
()
.
After
(
expiry
)
}
}
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func
(
u
*
URLAvailability
)
GetAvailableURLs
()
[]
string
{
func
(
u
*
URLAvailability
)
GetAvailableURLs
()
[]
string
{
u
.
mu
.
RLock
()
u
.
mu
.
RLock
()
defer
u
.
mu
.
RUnlock
()
defer
u
.
mu
.
RUnlock
()
now
:=
time
.
Now
()
now
:=
time
.
Now
()
result
:=
make
([]
string
,
0
,
len
(
BaseURLs
))
result
:=
make
([]
string
,
0
,
len
(
BaseURLs
))
// 如果有最近成功的 URL 且可用,放在最前面
if
u
.
lastSuccess
!=
""
{
expiry
,
exists
:=
u
.
unavailable
[
u
.
lastSuccess
]
if
!
exists
||
now
.
After
(
expiry
)
{
result
=
append
(
result
,
u
.
lastSuccess
)
}
}
// 添加其他可用的 URL(按默认顺序)
for
_
,
url
:=
range
BaseURLs
{
for
_
,
url
:=
range
BaseURLs
{
// 跳过已添加的 lastSuccess
if
url
==
u
.
lastSuccess
{
continue
}
expiry
,
exists
:=
u
.
unavailable
[
url
]
expiry
,
exists
:=
u
.
unavailable
[
url
]
if
!
exists
||
now
.
After
(
expiry
)
{
if
!
exists
||
now
.
After
(
expiry
)
{
result
=
append
(
result
,
url
)
result
=
append
(
result
,
url
)
...
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
...
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return
fmt
.
Sprintf
(
"%s?%s"
,
AuthorizeURL
,
params
.
Encode
())
return
fmt
.
Sprintf
(
"%s?%s"
,
AuthorizeURL
,
params
.
Encode
())
}
}
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func
GenerateMockProjectID
()
string
{
adjectives
:=
[]
string
{
"useful"
,
"bright"
,
"swift"
,
"calm"
,
"bold"
}
nouns
:=
[]
string
{
"fuze"
,
"wave"
,
"spark"
,
"flow"
,
"core"
}
randBytes
,
_
:=
GenerateRandomBytes
(
7
)
adj
:=
adjectives
[
int
(
randBytes
[
0
])
%
len
(
adjectives
)]
noun
:=
nouns
[
int
(
randBytes
[
1
])
%
len
(
nouns
)]
// 生成 5 位随机字符(a-z0-9)
const
charset
=
"abcdefghijklmnopqrstuvwxyz0123456789"
suffix
:=
make
([]
byte
,
5
)
for
i
:=
0
;
i
<
5
;
i
++
{
suffix
[
i
]
=
charset
[
int
(
randBytes
[
i
+
2
])
%
len
(
charset
)]
}
return
fmt
.
Sprintf
(
"%s-%s-%s"
,
adj
,
noun
,
string
(
suffix
))
}
backend/internal/pkg/antigravity/request_transformer.go
View file @
292f25f9
...
@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
...
@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
}
}
}
}
// webSearchFallbackModel web_search 请求使用的降级模型
const
webSearchFallbackModel
=
"gemini-2.5-flash"
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func
TransformClaudeToGemini
(
claudeReq
*
ClaudeRequest
,
projectID
,
mappedModel
string
)
([]
byte
,
error
)
{
func
TransformClaudeToGemini
(
claudeReq
*
ClaudeRequest
,
projectID
,
mappedModel
string
)
([]
byte
,
error
)
{
return
TransformClaudeToGeminiWithOptions
(
claudeReq
,
projectID
,
mappedModel
,
DefaultTransformOptions
())
return
TransformClaudeToGeminiWithOptions
(
claudeReq
,
projectID
,
mappedModel
,
DefaultTransformOptions
())
...
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
...
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射
// 用于存储 tool_use id -> name 映射
toolIDToName
:=
make
(
map
[
string
]
string
)
toolIDToName
:=
make
(
map
[
string
]
string
)
// 检测是否有 web_search 工具
hasWebSearchTool
:=
hasWebSearchTool
(
claudeReq
.
Tools
)
requestType
:=
"agent"
targetModel
:=
mappedModel
if
hasWebSearchTool
{
requestType
=
"web_search"
if
targetModel
!=
webSearchFallbackModel
{
targetModel
=
webSearchFallbackModel
}
}
// 检测是否启用 thinking
// 检测是否启用 thinking
isThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
isThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought
:=
strings
.
HasPrefix
(
mapped
Model
,
"gemini-"
)
allowDummyThought
:=
strings
.
HasPrefix
(
target
Model
,
"gemini-"
)
// 1. 构建 contents
// 1. 构建 contents
contents
,
strippedThinking
,
err
:=
buildContents
(
claudeReq
.
Messages
,
toolIDToName
,
isThinkingEnabled
,
allowDummyThought
)
contents
,
strippedThinking
,
err
:=
buildContents
(
claudeReq
.
Messages
,
toolIDToName
,
isThinkingEnabled
,
allowDummyThought
)
...
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
...
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
}
// 2. 构建 systemInstruction
// 2. 构建 systemInstruction
systemInstruction
:=
buildSystemInstruction
(
claudeReq
.
System
,
claudeReq
.
Model
,
opts
)
systemInstruction
:=
buildSystemInstruction
(
claudeReq
.
System
,
claudeReq
.
Model
,
opts
,
claudeReq
.
Tools
)
// 3. 构建 generationConfig
// 3. 构建 generationConfig
reqForConfig
:=
claudeReq
reqForConfig
:=
claudeReq
...
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
...
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy
.
Thinking
=
nil
reqCopy
.
Thinking
=
nil
reqForConfig
=
&
reqCopy
reqForConfig
=
&
reqCopy
}
}
if
targetModel
!=
""
&&
targetModel
!=
reqForConfig
.
Model
{
reqCopy
:=
*
reqForConfig
reqCopy
.
Model
=
targetModel
reqForConfig
=
&
reqCopy
}
generationConfig
:=
buildGenerationConfig
(
reqForConfig
)
generationConfig
:=
buildGenerationConfig
(
reqForConfig
)
// 4. 构建 tools
// 4. 构建 tools
...
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
...
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project
:
projectID
,
Project
:
projectID
,
RequestID
:
"agent-"
+
uuid
.
New
()
.
String
(),
RequestID
:
"agent-"
+
uuid
.
New
()
.
String
(),
UserAgent
:
"antigravity"
,
// 固定值,与官方客户端一致
UserAgent
:
"antigravity"
,
// 固定值,与官方客户端一致
RequestType
:
"agent"
,
RequestType
:
requestType
,
Model
:
mapped
Model
,
Model
:
target
Model
,
Request
:
innerRequest
,
Request
:
innerRequest
,
}
}
...
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
...
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
return
antigravityIdentity
return
antigravityIdentity
}
}
// buildSystemInstruction 构建 systemInstruction
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
func
buildSystemInstruction
(
system
json
.
RawMessage
,
modelName
string
,
opts
TransformOptions
)
*
GeminiContent
{
const
mcpXMLProtocol
=
`
==== MCP XML 工具调用协议 (Workaround) ====
当你需要调用名称以 `
+
"`mcp__`"
+
` 开头的 MCP 工具时:
1) 优先尝试 XML 格式调用:输出 `
+
"`<mcp__tool_name>{
\"
arg
\"
:
\"
value
\"
}</mcp__tool_name>`"
+
`。
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
===========================================`
// hasMCPTools 检测是否有 mcp__ 前缀的工具
func
hasMCPTools
(
tools
[]
ClaudeTool
)
bool
{
for
_
,
tool
:=
range
tools
{
if
strings
.
HasPrefix
(
tool
.
Name
,
"mcp__"
)
{
return
true
}
}
return
false
}
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
func
filterOpenCodePrompt
(
text
string
)
string
{
if
!
strings
.
Contains
(
text
,
"You are an interactive CLI tool"
)
{
return
text
}
// 提取 "Instructions from:" 及之后的部分
if
idx
:=
strings
.
Index
(
text
,
"Instructions from:"
);
idx
>=
0
{
return
text
[
idx
:
]
}
// 如果没有自定义指令,返回空
return
""
}
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
func
buildSystemInstruction
(
system
json
.
RawMessage
,
modelName
string
,
opts
TransformOptions
,
tools
[]
ClaudeTool
)
*
GeminiContent
{
var
parts
[]
GeminiPart
var
parts
[]
GeminiPart
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
...
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
...
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var
sysStr
string
var
sysStr
string
if
err
:=
json
.
Unmarshal
(
system
,
&
sysStr
);
err
==
nil
{
if
err
:=
json
.
Unmarshal
(
system
,
&
sysStr
);
err
==
nil
{
if
strings
.
TrimSpace
(
sysStr
)
!=
""
{
if
strings
.
TrimSpace
(
sysStr
)
!=
""
{
userSystemParts
=
append
(
userSystemParts
,
GeminiPart
{
Text
:
sysStr
})
if
strings
.
Contains
(
sysStr
,
"You are Antigravity"
)
{
if
strings
.
Contains
(
sysStr
,
"You are Antigravity"
)
{
userHasAntigravityIdentity
=
true
userHasAntigravityIdentity
=
true
}
}
// 过滤 OpenCode 默认提示词
filtered
:=
filterOpenCodePrompt
(
sysStr
)
if
filtered
!=
""
{
userSystemParts
=
append
(
userSystemParts
,
GeminiPart
{
Text
:
filtered
})
}
}
}
}
else
{
}
else
{
// 尝试解析为数组
// 尝试解析为数组
...
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
...
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if
err
:=
json
.
Unmarshal
(
system
,
&
sysBlocks
);
err
==
nil
{
if
err
:=
json
.
Unmarshal
(
system
,
&
sysBlocks
);
err
==
nil
{
for
_
,
block
:=
range
sysBlocks
{
for
_
,
block
:=
range
sysBlocks
{
if
block
.
Type
==
"text"
&&
strings
.
TrimSpace
(
block
.
Text
)
!=
""
{
if
block
.
Type
==
"text"
&&
strings
.
TrimSpace
(
block
.
Text
)
!=
""
{
userSystemParts
=
append
(
userSystemParts
,
GeminiPart
{
Text
:
block
.
Text
})
if
strings
.
Contains
(
block
.
Text
,
"You are Antigravity"
)
{
if
strings
.
Contains
(
block
.
Text
,
"You are Antigravity"
)
{
userHasAntigravityIdentity
=
true
userHasAntigravityIdentity
=
true
}
}
// 过滤 OpenCode 默认提示词
filtered
:=
filterOpenCodePrompt
(
block
.
Text
)
if
filtered
!=
""
{
userSystemParts
=
append
(
userSystemParts
,
GeminiPart
{
Text
:
filtered
})
}
}
}
}
}
}
}
...
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
...
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
// 添加用户的 system prompt
parts
=
append
(
parts
,
userSystemParts
...
)
parts
=
append
(
parts
,
userSystemParts
...
)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if
hasMCPTools
(
tools
)
{
parts
=
append
(
parts
,
GeminiPart
{
Text
:
mcpXMLProtocol
})
}
// 如果用户没有提供 Antigravity 身份,添加结束标记
if
!
userHasAntigravityIdentity
{
parts
=
append
(
parts
,
GeminiPart
{
Text
:
"
\n
--- [SYSTEM_PROMPT_END] ---"
})
}
if
len
(
parts
)
==
0
{
if
len
(
parts
)
==
0
{
return
nil
return
nil
}
}
...
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
...
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences
:
DefaultStopSequences
,
StopSequences
:
DefaultStopSequences
,
}
}
// 如果请求中指定了 MaxTokens,使用请求值
if
req
.
MaxTokens
>
0
{
config
.
MaxOutputTokens
=
req
.
MaxTokens
}
// Thinking 配置
// Thinking 配置
if
req
.
Thinking
!=
nil
&&
req
.
Thinking
.
Type
==
"enabled"
{
if
req
.
Thinking
!=
nil
&&
req
.
Thinking
.
Type
==
"enabled"
{
config
.
ThinkingConfig
=
&
GeminiThinkingConfig
{
config
.
ThinkingConfig
=
&
GeminiThinkingConfig
{
...
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
...
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return
config
return
config
}
}
func
hasWebSearchTool
(
tools
[]
ClaudeTool
)
bool
{
for
_
,
tool
:=
range
tools
{
if
isWebSearchTool
(
tool
)
{
return
true
}
}
return
false
}
func
isWebSearchTool
(
tool
ClaudeTool
)
bool
{
if
strings
.
HasPrefix
(
tool
.
Type
,
"web_search"
)
||
tool
.
Type
==
"google_search"
{
return
true
}
name
:=
strings
.
TrimSpace
(
tool
.
Name
)
switch
name
{
case
"web_search"
,
"google_search"
,
"web_search_20250305"
:
return
true
default
:
return
false
}
}
// buildTools 构建 tools
// buildTools 构建 tools
func
buildTools
(
tools
[]
ClaudeTool
)
[]
GeminiToolDeclaration
{
func
buildTools
(
tools
[]
ClaudeTool
)
[]
GeminiToolDeclaration
{
if
len
(
tools
)
==
0
{
if
len
(
tools
)
==
0
{
return
nil
return
nil
}
}
// 检查是否有 web_search 工具
hasWebSearch
:=
hasWebSearchTool
(
tools
)
hasWebSearch
:=
false
for
_
,
tool
:=
range
tools
{
if
tool
.
Name
==
"web_search"
{
hasWebSearch
=
true
break
}
}
if
hasWebSearch
{
// Web Search 工具映射
return
[]
GeminiToolDeclaration
{{
GoogleSearch
:
&
GeminiGoogleSearch
{
EnhancedContent
:
&
GeminiEnhancedContent
{
ImageSearch
:
&
GeminiImageSearch
{
MaxResultCount
:
5
,
},
},
},
}}
}
// 普通工具
// 普通工具
var
funcDecls
[]
GeminiFunctionDecl
var
funcDecls
[]
GeminiFunctionDecl
for
_
,
tool
:=
range
tools
{
for
_
,
tool
:=
range
tools
{
if
isWebSearchTool
(
tool
)
{
continue
}
// 跳过无效工具名称
// 跳过无效工具名称
if
strings
.
TrimSpace
(
tool
.
Name
)
==
""
{
if
strings
.
TrimSpace
(
tool
.
Name
)
==
""
{
log
.
Printf
(
"Warning: skipping tool with empty name"
)
log
.
Printf
(
"Warning: skipping tool with empty name"
)
...
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
if
len
(
funcDecls
)
==
0
{
if
len
(
funcDecls
)
==
0
{
return
nil
if
!
hasWebSearch
{
return
nil
}
// Web Search 工具映射
return
[]
GeminiToolDeclaration
{{
GoogleSearch
:
&
GeminiGoogleSearch
{
EnhancedContent
:
&
GeminiEnhancedContent
{
ImageSearch
:
&
GeminiImageSearch
{
MaxResultCount
:
5
,
},
},
},
}}
}
}
return
[]
GeminiToolDeclaration
{{
return
[]
GeminiToolDeclaration
{{
...
...
backend/internal/pkg/antigravity/response_transformer.go
View file @
292f25f9
...
@@ -3,6 +3,7 @@ package antigravity
...
@@ -3,6 +3,7 @@ package antigravity
import
(
import
(
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"strings"
)
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
...
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
...
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p
.
processPart
(
&
part
)
p
.
processPart
(
&
part
)
}
}
if
len
(
geminiResp
.
Candidates
)
>
0
{
if
grounding
:=
geminiResp
.
Candidates
[
0
]
.
GroundingMetadata
;
grounding
!=
nil
{
p
.
processGrounding
(
grounding
)
}
}
// 刷新剩余内容
// 刷新剩余内容
p
.
flushThinking
()
p
.
flushThinking
()
p
.
flushText
()
p
.
flushText
()
...
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
...
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
}
}
}
}
func
(
p
*
NonStreamingProcessor
)
processGrounding
(
grounding
*
GeminiGroundingMetadata
)
{
groundingText
:=
buildGroundingText
(
grounding
)
if
groundingText
==
""
{
return
}
p
.
flushThinking
()
p
.
flushText
()
p
.
textBuilder
+=
groundingText
p
.
flushText
()
}
// flushText 刷新 text builder
// flushText 刷新 text builder
func
(
p
*
NonStreamingProcessor
)
flushText
()
{
func
(
p
*
NonStreamingProcessor
)
flushText
()
{
if
p
.
textBuilder
==
""
{
if
p
.
textBuilder
==
""
{
...
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
...
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
}
}
}
}
func
buildGroundingText
(
grounding
*
GeminiGroundingMetadata
)
string
{
if
grounding
==
nil
{
return
""
}
var
builder
strings
.
Builder
if
len
(
grounding
.
WebSearchQueries
)
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n\n
---
\n
Web search queries: "
)
_
,
_
=
builder
.
WriteString
(
strings
.
Join
(
grounding
.
WebSearchQueries
,
", "
))
}
if
len
(
grounding
.
GroundingChunks
)
>
0
{
var
links
[]
string
for
i
,
chunk
:=
range
grounding
.
GroundingChunks
{
if
chunk
.
Web
==
nil
{
continue
}
title
:=
strings
.
TrimSpace
(
chunk
.
Web
.
Title
)
if
title
==
""
{
title
=
"Source"
}
uri
:=
strings
.
TrimSpace
(
chunk
.
Web
.
URI
)
if
uri
==
""
{
uri
=
"#"
}
links
=
append
(
links
,
fmt
.
Sprintf
(
"[%d] [%s](%s)"
,
i
+
1
,
title
,
uri
))
}
if
len
(
links
)
>
0
{
_
,
_
=
builder
.
WriteString
(
"
\n\n
Sources:
\n
"
)
_
,
_
=
builder
.
WriteString
(
strings
.
Join
(
links
,
"
\n
"
))
}
}
return
builder
.
String
()
}
// generateRandomID 生成随机 ID
// generateRandomID 生成随机 ID
func
generateRandomID
()
string
{
func
generateRandomID
()
string
{
const
chars
=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
const
chars
=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
...
...
backend/internal/pkg/antigravity/stream_transformer.go
View file @
292f25f9
...
@@ -27,6 +27,8 @@ type StreamingProcessor struct {
...
@@ -27,6 +27,8 @@ type StreamingProcessor struct {
pendingSignature
string
pendingSignature
string
trailingSignature
string
trailingSignature
string
originalModel
string
originalModel
string
webSearchQueries
[]
string
groundingChunks
[]
GeminiGroundingChunk
// 累计 usage
// 累计 usage
inputTokens
int
inputTokens
int
...
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
...
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
}
}
}
if
len
(
geminiResp
.
Candidates
)
>
0
{
p
.
captureGrounding
(
geminiResp
.
Candidates
[
0
]
.
GroundingMetadata
)
}
// 检查是否结束
// 检查是否结束
if
len
(
geminiResp
.
Candidates
)
>
0
{
if
len
(
geminiResp
.
Candidates
)
>
0
{
finishReason
:=
geminiResp
.
Candidates
[
0
]
.
FinishReason
finishReason
:=
geminiResp
.
Candidates
[
0
]
.
FinishReason
...
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
...
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return
result
.
Bytes
()
return
result
.
Bytes
()
}
}
func
(
p
*
StreamingProcessor
)
captureGrounding
(
grounding
*
GeminiGroundingMetadata
)
{
if
grounding
==
nil
{
return
}
if
len
(
grounding
.
WebSearchQueries
)
>
0
&&
len
(
p
.
webSearchQueries
)
==
0
{
p
.
webSearchQueries
=
append
([]
string
(
nil
),
grounding
.
WebSearchQueries
...
)
}
if
len
(
grounding
.
GroundingChunks
)
>
0
&&
len
(
p
.
groundingChunks
)
==
0
{
p
.
groundingChunks
=
append
([]
GeminiGroundingChunk
(
nil
),
grounding
.
GroundingChunks
...
)
}
}
// processThinking 处理 thinking
// processThinking 处理 thinking
func
(
p
*
StreamingProcessor
)
processThinking
(
text
,
signature
string
)
[]
byte
{
func
(
p
*
StreamingProcessor
)
processThinking
(
text
,
signature
string
)
[]
byte
{
var
result
bytes
.
Buffer
var
result
bytes
.
Buffer
...
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
...
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p
.
trailingSignature
=
""
p
.
trailingSignature
=
""
}
}
if
len
(
p
.
webSearchQueries
)
>
0
||
len
(
p
.
groundingChunks
)
>
0
{
groundingText
:=
buildGroundingText
(
&
GeminiGroundingMetadata
{
WebSearchQueries
:
p
.
webSearchQueries
,
GroundingChunks
:
p
.
groundingChunks
,
})
if
groundingText
!=
""
{
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeText
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
""
,
}))
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"text_delta"
,
map
[
string
]
any
{
"text"
:
groundingText
,
}))
_
,
_
=
result
.
Write
(
p
.
endBlock
())
}
}
// 确定 stop_reason
// 确定 stop_reason
stopReason
:=
"end_turn"
stopReason
:=
"end_turn"
if
p
.
usedTool
{
if
p
.
usedTool
{
...
...
backend/internal/pkg/response/response.go
View file @
292f25f9
...
@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
...
@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
// 支持 page_size 和 limit 两种参数名
// 支持 page_size 和 limit 两种参数名
if
ps
:=
c
.
Query
(
"page_size"
);
ps
!=
""
{
if
ps
:=
c
.
Query
(
"page_size"
);
ps
!=
""
{
if
val
,
err
:=
parseInt
(
ps
);
err
==
nil
&&
val
>
0
&&
val
<=
100
{
if
val
,
err
:=
parseInt
(
ps
);
err
==
nil
&&
val
>
0
&&
val
<=
100
0
{
pageSize
=
val
pageSize
=
val
}
}
}
else
if
l
:=
c
.
Query
(
"limit"
);
l
!=
""
{
}
else
if
l
:=
c
.
Query
(
"limit"
);
l
!=
""
{
if
val
,
err
:=
parseInt
(
l
);
err
==
nil
&&
val
>
0
&&
val
<=
100
{
if
val
,
err
:=
parseInt
(
l
);
err
==
nil
&&
val
>
0
&&
val
<=
100
0
{
pageSize
=
val
pageSize
=
val
}
}
}
}
...
...
backend/internal/pkg/tlsfingerprint/dialer.go
0 → 100644
View file @
292f25f9
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
package
tlsfingerprint
import
(
"bufio"
"context"
"encoding/base64"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
utls
"github.com/refraction-networking/utls"
"golang.org/x/net/proxy"
)
// Profile contains TLS fingerprint configuration.
type
Profile
struct
{
Name
string
// Profile name for identification
CipherSuites
[]
uint16
Curves
[]
uint16
PointFormats
[]
uint8
EnableGREASE
bool
}
// Dialer creates TLS connections with custom fingerprints.
type
Dialer
struct
{
profile
*
Profile
baseDialer
func
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
}
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
// It handles the CONNECT tunnel establishment before performing TLS handshake.
type
HTTPProxyDialer
struct
{
profile
*
Profile
proxyURL
*
url
.
URL
}
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
type
SOCKS5ProxyDialer
struct
{
profile
*
Profile
proxyURL
*
url
.
URL
}
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
//
// Note: JA3/JA4 may have slight variations due to:
// - Session ticket presence/absence
// - Extension negotiation state
var
(
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
// Order is critical for JA3 fingerprint matching
defaultCipherSuites
=
[]
uint16
{
// TLS 1.3 cipher suites (MUST be first)
0x1302
,
// TLS_AES_256_GCM_SHA384
0x1303
,
// TLS_CHACHA20_POLY1305_SHA256
0x1301
,
// TLS_AES_128_GCM_SHA256
// ECDHE + AES-GCM
0xc02f
,
// TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
0xc02b
,
// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
0xc030
,
// TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
0xc02c
,
// TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
// DHE + AES-GCM
0x009e
,
// TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA256/384
0xc027
,
// TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
0x0067
,
// TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
0xc028
,
// TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
0x006b
,
// TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
// DHE-DSS/RSA + AES-GCM
0x00a3
,
// TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
0x009f
,
// TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
// ChaCha20-Poly1305
0xcca9
,
// TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
0xcca8
,
// TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
0xccaa
,
// TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
// AES-CCM (256-bit)
0xc0af
,
// TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
0xc0ad
,
// TLS_ECDHE_ECDSA_WITH_AES_256_CCM
0xc0a3
,
// TLS_DHE_RSA_WITH_AES_256_CCM_8
0xc09f
,
// TLS_DHE_RSA_WITH_AES_256_CCM
// ARIA (256-bit)
0xc05d
,
// TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
0xc061
,
// TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
0xc057
,
// TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
0xc053
,
// TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
// DHE-DSS + AES-GCM (128-bit)
0x00a2
,
// TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
// AES-CCM (128-bit)
0xc0ae
,
// TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
0xc0ac
,
// TLS_ECDHE_ECDSA_WITH_AES_128_CCM
0xc0a2
,
// TLS_DHE_RSA_WITH_AES_128_CCM_8
0xc09e
,
// TLS_DHE_RSA_WITH_AES_128_CCM
// ARIA (128-bit)
0xc05c
,
// TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
0xc060
,
// TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
0xc056
,
// TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
0xc052
,
// TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
0xc024
,
// TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
0x006a
,
// TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
0xc023
,
// TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
0x0040
,
// TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
// ECDHE/DHE + AES-CBC-SHA (legacy)
0xc00a
,
// TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
0xc014
,
// TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
0x0039
,
// TLS_DHE_RSA_WITH_AES_256_CBC_SHA
0x0038
,
// TLS_DHE_DSS_WITH_AES_256_CBC_SHA
0xc009
,
// TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
0xc013
,
// TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
0x0033
,
// TLS_DHE_RSA_WITH_AES_128_CBC_SHA
0x0032
,
// TLS_DHE_DSS_WITH_AES_128_CBC_SHA
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
0x009d
,
// TLS_RSA_WITH_AES_256_GCM_SHA384
0xc0a1
,
// TLS_RSA_WITH_AES_256_CCM_8
0xc09d
,
// TLS_RSA_WITH_AES_256_CCM
0xc051
,
// TLS_RSA_WITH_ARIA_256_GCM_SHA384
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
0x009c
,
// TLS_RSA_WITH_AES_128_GCM_SHA256
0xc0a0
,
// TLS_RSA_WITH_AES_128_CCM_8
0xc09c
,
// TLS_RSA_WITH_AES_128_CCM
0xc050
,
// TLS_RSA_WITH_ARIA_128_GCM_SHA256
// RSA + AES-CBC (non-PFS, legacy)
0x003d
,
// TLS_RSA_WITH_AES_256_CBC_SHA256
0x003c
,
// TLS_RSA_WITH_AES_128_CBC_SHA256
0x0035
,
// TLS_RSA_WITH_AES_256_CBC_SHA
0x002f
,
// TLS_RSA_WITH_AES_128_CBC_SHA
// Renegotiation indication
0x00ff
,
// TLS_EMPTY_RENEGOTIATION_INFO_SCSV
}
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
defaultCurves
=
[]
utls
.
CurveID
{
utls
.
X25519
,
// 0x001d
utls
.
CurveP256
,
// 0x0017 (secp256r1)
utls
.
CurveID
(
0x001e
),
// x448
utls
.
CurveP521
,
// 0x0019 (secp521r1)
utls
.
CurveP384
,
// 0x0018 (secp384r1)
utls
.
CurveID
(
0x0100
),
// ffdhe2048
utls
.
CurveID
(
0x0101
),
// ffdhe3072
utls
.
CurveID
(
0x0102
),
// ffdhe4096
utls
.
CurveID
(
0x0103
),
// ffdhe6144
utls
.
CurveID
(
0x0104
),
// ffdhe8192
}
// defaultPointFormats contains all 3 point formats from Claude CLI
defaultPointFormats
=
[]
uint8
{
0
,
// uncompressed
1
,
// ansiX962_compressed_prime
2
,
// ansiX962_compressed_char2
}
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
defaultSignatureAlgorithms
=
[]
utls
.
SignatureScheme
{
0x0403
,
// ecdsa_secp256r1_sha256
0x0503
,
// ecdsa_secp384r1_sha384
0x0603
,
// ecdsa_secp521r1_sha512
0x0807
,
// ed25519
0x0808
,
// ed448
0x0809
,
// rsa_pss_pss_sha256
0x080a
,
// rsa_pss_pss_sha384
0x080b
,
// rsa_pss_pss_sha512
0x0804
,
// rsa_pss_rsae_sha256
0x0805
,
// rsa_pss_rsae_sha384
0x0806
,
// rsa_pss_rsae_sha512
0x0401
,
// rsa_pkcs1_sha256
0x0501
,
// rsa_pkcs1_sha384
0x0601
,
// rsa_pkcs1_sha512
0x0303
,
// ecdsa_sha224
0x0301
,
// rsa_pkcs1_sha224
0x0302
,
// dsa_sha224
0x0402
,
// dsa_sha256
0x0502
,
// dsa_sha384
0x0602
,
// dsa_sha512
}
)
// NewDialer creates a new TLS fingerprint dialer.
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
// If baseDialer is nil, direct TCP dial is used.
func
NewDialer
(
profile
*
Profile
,
baseDialer
func
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
))
*
Dialer
{
if
baseDialer
==
nil
{
baseDialer
=
(
&
net
.
Dialer
{})
.
DialContext
}
return
&
Dialer
{
profile
:
profile
,
baseDialer
:
baseDialer
}
}
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
func
NewHTTPProxyDialer
(
profile
*
Profile
,
proxyURL
*
url
.
URL
)
*
HTTPProxyDialer
{
return
&
HTTPProxyDialer
{
profile
:
profile
,
proxyURL
:
proxyURL
}
}
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
func
NewSOCKS5ProxyDialer
(
profile
*
Profile
,
proxyURL
*
url
.
URL
)
*
SOCKS5ProxyDialer
{
return
&
SOCKS5ProxyDialer
{
profile
:
profile
,
proxyURL
:
proxyURL
}
}
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
func
(
d
*
SOCKS5ProxyDialer
)
DialTLSContext
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
slog
.
Debug
(
"tls_fingerprint_socks5_connecting"
,
"proxy"
,
d
.
proxyURL
.
Host
,
"target"
,
addr
)
// Step 1: Create SOCKS5 dialer
var
auth
*
proxy
.
Auth
if
d
.
proxyURL
.
User
!=
nil
{
username
:=
d
.
proxyURL
.
User
.
Username
()
password
,
_
:=
d
.
proxyURL
.
User
.
Password
()
auth
=
&
proxy
.
Auth
{
User
:
username
,
Password
:
password
,
}
}
// Determine proxy address
proxyAddr
:=
d
.
proxyURL
.
Host
if
d
.
proxyURL
.
Port
()
==
""
{
proxyAddr
=
net
.
JoinHostPort
(
d
.
proxyURL
.
Hostname
(),
"1080"
)
// Default SOCKS5 port
}
socksDialer
,
err
:=
proxy
.
SOCKS5
(
"tcp"
,
proxyAddr
,
auth
,
proxy
.
Direct
)
if
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_dialer_failed"
,
"error"
,
err
)
return
nil
,
fmt
.
Errorf
(
"create SOCKS5 dialer: %w"
,
err
)
}
// Step 2: Establish SOCKS5 tunnel to target
slog
.
Debug
(
"tls_fingerprint_socks5_establishing_tunnel"
,
"target"
,
addr
)
conn
,
err
:=
socksDialer
.
Dial
(
"tcp"
,
addr
)
if
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_connect_failed"
,
"error"
,
err
)
return
nil
,
fmt
.
Errorf
(
"SOCKS5 connect: %w"
,
err
)
}
slog
.
Debug
(
"tls_fingerprint_socks5_tunnel_established"
)
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
host
,
_
,
err
:=
net
.
SplitHostPort
(
addr
)
if
err
!=
nil
{
host
=
addr
}
slog
.
Debug
(
"tls_fingerprint_socks5_starting_handshake"
,
"host"
,
host
)
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
spec
:=
buildClientHelloSpecFromProfile
(
d
.
profile
)
slog
.
Debug
(
"tls_fingerprint_socks5_clienthello_spec"
,
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"extensions"
,
len
(
spec
.
Extensions
),
"compression_methods"
,
spec
.
CompressionMethods
,
"tls_vers_max"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMax
),
"tls_vers_min"
,
fmt
.
Sprintf
(
"0x%04x"
,
spec
.
TLSVersMin
))
if
d
.
profile
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
}
// Create uTLS connection on the tunnel
tlsConn
:=
utls
.
UClient
(
conn
,
&
utls
.
Config
{
ServerName
:
host
,
},
utls
.
HelloCustom
)
if
err
:=
tlsConn
.
ApplyPreset
(
spec
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_apply_preset_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"apply TLS preset: %w"
,
err
)
}
if
err
:=
tlsConn
.
Handshake
();
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"TLS handshake failed: %w"
,
err
)
}
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
),
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
),
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
}
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
func
(
d
*
HTTPProxyDialer
)
DialTLSContext
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
slog
.
Debug
(
"tls_fingerprint_http_proxy_connecting"
,
"proxy"
,
d
.
proxyURL
.
Host
,
"target"
,
addr
)
// Step 1: TCP connect to proxy server
var
proxyAddr
string
if
d
.
proxyURL
.
Port
()
!=
""
{
proxyAddr
=
d
.
proxyURL
.
Host
}
else
{
// Default ports
if
d
.
proxyURL
.
Scheme
==
"https"
{
proxyAddr
=
net
.
JoinHostPort
(
d
.
proxyURL
.
Hostname
(),
"443"
)
}
else
{
proxyAddr
=
net
.
JoinHostPort
(
d
.
proxyURL
.
Hostname
(),
"80"
)
}
}
dialer
:=
&
net
.
Dialer
{}
conn
,
err
:=
dialer
.
DialContext
(
ctx
,
"tcp"
,
proxyAddr
)
if
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_http_proxy_connect_failed"
,
"error"
,
err
)
return
nil
,
fmt
.
Errorf
(
"connect to proxy: %w"
,
err
)
}
slog
.
Debug
(
"tls_fingerprint_http_proxy_connected"
,
"proxy_addr"
,
proxyAddr
)
// Step 2: Send CONNECT request to establish tunnel
req
:=
&
http
.
Request
{
Method
:
"CONNECT"
,
URL
:
&
url
.
URL
{
Opaque
:
addr
},
Host
:
addr
,
Header
:
make
(
http
.
Header
),
}
// Add proxy authentication if present
if
d
.
proxyURL
.
User
!=
nil
{
username
:=
d
.
proxyURL
.
User
.
Username
()
password
,
_
:=
d
.
proxyURL
.
User
.
Password
()
auth
:=
base64
.
StdEncoding
.
EncodeToString
([]
byte
(
username
+
":"
+
password
))
req
.
Header
.
Set
(
"Proxy-Authorization"
,
"Basic "
+
auth
)
}
slog
.
Debug
(
"tls_fingerprint_http_proxy_sending_connect"
,
"target"
,
addr
)
if
err
:=
req
.
Write
(
conn
);
err
!=
nil
{
_
=
conn
.
Close
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_write_failed"
,
"error"
,
err
)
return
nil
,
fmt
.
Errorf
(
"write CONNECT request: %w"
,
err
)
}
// Step 3: Read CONNECT response
br
:=
bufio
.
NewReader
(
conn
)
resp
,
err
:=
http
.
ReadResponse
(
br
,
req
)
if
err
!=
nil
{
_
=
conn
.
Close
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_read_response_failed"
,
"error"
,
err
)
return
nil
,
fmt
.
Errorf
(
"read CONNECT response: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
_
=
conn
.
Close
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_connect_failed_status"
,
"status_code"
,
resp
.
StatusCode
,
"status"
,
resp
.
Status
)
return
nil
,
fmt
.
Errorf
(
"proxy CONNECT failed: %s"
,
resp
.
Status
)
}
slog
.
Debug
(
"tls_fingerprint_http_proxy_tunnel_established"
)
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
host
,
_
,
err
:=
net
.
SplitHostPort
(
addr
)
if
err
!=
nil
{
host
=
addr
}
slog
.
Debug
(
"tls_fingerprint_http_proxy_starting_handshake"
,
"host"
,
host
)
// Build ClientHello specification (reuse the shared method)
spec
:=
buildClientHelloSpecFromProfile
(
d
.
profile
)
slog
.
Debug
(
"tls_fingerprint_http_proxy_clienthello_spec"
,
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"extensions"
,
len
(
spec
.
Extensions
))
if
d
.
profile
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_http_proxy_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
}
// Create uTLS connection on the tunnel
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn
:=
utls
.
UClient
(
conn
,
&
utls
.
Config
{
ServerName
:
host
,
},
utls
.
HelloCustom
)
if
err
:=
tlsConn
.
ApplyPreset
(
spec
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_http_proxy_apply_preset_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"apply TLS preset: %w"
,
err
)
}
if
err
:=
tlsConn
.
HandshakeContext
(
ctx
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_http_proxy_handshake_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"TLS handshake failed: %w"
,
err
)
}
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_http_proxy_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
),
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
),
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
}
// DialTLSContext establishes a TLS connection with the configured fingerprint.
// This method is designed to be used as http.Transport.DialTLSContext.
func
(
d
*
Dialer
)
DialTLSContext
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
// Establish TCP connection using base dialer (supports proxy)
slog
.
Debug
(
"tls_fingerprint_dialing_tcp"
,
"addr"
,
addr
)
conn
,
err
:=
d
.
baseDialer
(
ctx
,
network
,
addr
)
if
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_tcp_dial_failed"
,
"error"
,
err
)
return
nil
,
err
}
slog
.
Debug
(
"tls_fingerprint_tcp_connected"
,
"addr"
,
addr
)
// Extract hostname for SNI
host
,
_
,
err
:=
net
.
SplitHostPort
(
addr
)
if
err
!=
nil
{
host
=
addr
}
slog
.
Debug
(
"tls_fingerprint_sni_hostname"
,
"host"
,
host
)
// Build ClientHello specification
spec
:=
d
.
buildClientHelloSpec
()
slog
.
Debug
(
"tls_fingerprint_clienthello_spec"
,
"cipher_suites"
,
len
(
spec
.
CipherSuites
),
"extensions"
,
len
(
spec
.
Extensions
))
// Log profile info
if
d
.
profile
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_using_profile"
,
"name"
,
d
.
profile
.
Name
,
"grease"
,
d
.
profile
.
EnableGREASE
)
}
else
{
slog
.
Debug
(
"tls_fingerprint_using_default_profile"
)
}
// Create uTLS connection
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn
:=
utls
.
UClient
(
conn
,
&
utls
.
Config
{
ServerName
:
host
,
},
utls
.
HelloCustom
)
// Apply fingerprint
if
err
:=
tlsConn
.
ApplyPreset
(
spec
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_apply_preset_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
err
}
slog
.
Debug
(
"tls_fingerprint_preset_applied"
)
// Perform TLS handshake
if
err
:=
tlsConn
.
HandshakeContext
(
ctx
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_handshake_failed"
,
"error"
,
err
,
"local_addr"
,
conn
.
LocalAddr
(),
"remote_addr"
,
conn
.
RemoteAddr
())
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"TLS handshake failed: %w"
,
err
)
}
// Log successful handshake details
state
:=
tlsConn
.
ConnectionState
()
slog
.
Debug
(
"tls_fingerprint_handshake_success"
,
"version"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
Version
),
"cipher_suite"
,
fmt
.
Sprintf
(
"0x%04x"
,
state
.
CipherSuite
),
"alpn"
,
state
.
NegotiatedProtocol
)
return
tlsConn
,
nil
}
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
func
(
d
*
Dialer
)
buildClientHelloSpec
()
*
utls
.
ClientHelloSpec
{
return
buildClientHelloSpecFromProfile
(
d
.
profile
)
}
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
func
toUTLSCurves
(
curves
[]
uint16
)
[]
utls
.
CurveID
{
result
:=
make
([]
utls
.
CurveID
,
len
(
curves
))
for
i
,
c
:=
range
curves
{
result
[
i
]
=
utls
.
CurveID
(
c
)
}
return
result
}
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
func
buildClientHelloSpecFromProfile
(
profile
*
Profile
)
*
utls
.
ClientHelloSpec
{
// Get cipher suites
var
cipherSuites
[]
uint16
if
profile
!=
nil
&&
len
(
profile
.
CipherSuites
)
>
0
{
cipherSuites
=
profile
.
CipherSuites
}
else
{
cipherSuites
=
defaultCipherSuites
}
// Get curves
var
curves
[]
utls
.
CurveID
if
profile
!=
nil
&&
len
(
profile
.
Curves
)
>
0
{
curves
=
toUTLSCurves
(
profile
.
Curves
)
}
else
{
curves
=
defaultCurves
}
// Get point formats
var
pointFormats
[]
uint8
if
profile
!=
nil
&&
len
(
profile
.
PointFormats
)
>
0
{
pointFormats
=
profile
.
PointFormats
}
else
{
pointFormats
=
defaultPointFormats
}
// Check if GREASE is enabled
enableGREASE
:=
profile
!=
nil
&&
profile
.
EnableGREASE
extensions
:=
make
([]
utls
.
TLSExtension
,
0
,
16
)
if
enableGREASE
{
extensions
=
append
(
extensions
,
&
utls
.
UtlsGREASEExtension
{})
}
// SNI extension - MUST be explicitly added for HelloCustom mode
// utls will populate the server name from Config.ServerName
extensions
=
append
(
extensions
,
&
utls
.
SNIExtension
{})
// Claude CLI extension order (captured from tshark):
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
// signature_algorithms(13), supported_versions(43),
// psk_key_exchange_modes(45), key_share(51)
extensions
=
append
(
extensions
,
&
utls
.
SupportedPointsExtension
{
SupportedPoints
:
pointFormats
},
&
utls
.
SupportedCurvesExtension
{
Curves
:
curves
},
&
utls
.
SessionTicketExtension
{},
&
utls
.
ALPNExtension
{
AlpnProtocols
:
[]
string
{
"http/1.1"
}},
&
utls
.
GenericExtension
{
Id
:
22
},
&
utls
.
ExtendedMasterSecretExtension
{},
&
utls
.
SignatureAlgorithmsExtension
{
SupportedSignatureAlgorithms
:
defaultSignatureAlgorithms
},
&
utls
.
SupportedVersionsExtension
{
Versions
:
[]
uint16
{
utls
.
VersionTLS13
,
utls
.
VersionTLS12
,
}},
&
utls
.
PSKKeyExchangeModesExtension
{
Modes
:
[]
uint8
{
utls
.
PskModeDHE
}},
&
utls
.
KeyShareExtension
{
KeyShares
:
[]
utls
.
KeyShare
{
{
Group
:
utls
.
X25519
},
}},
)
if
enableGREASE
{
extensions
=
append
(
extensions
,
&
utls
.
UtlsGREASEExtension
{})
}
return
&
utls
.
ClientHelloSpec
{
CipherSuites
:
cipherSuites
,
CompressionMethods
:
[]
uint8
{
0
},
// null compression only (standard)
Extensions
:
extensions
,
TLSVersMax
:
utls
.
VersionTLS13
,
TLSVersMin
:
utls
.
VersionTLS10
,
}
}
backend/internal/pkg/tlsfingerprint/dialer_test.go
0 → 100644
View file @
292f25f9
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
package
tlsfingerprint
import
(
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type
FingerprintResponse
struct
{
IP
string
`json:"ip"`
TLS
TLSInfo
`json:"tls"`
HTTP2
any
`json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type
TLSInfo
struct
{
JA3
string
`json:"ja3"`
JA3Hash
string
`json:"ja3_hash"`
JA4
string
`json:"ja4"`
PeetPrint
string
`json:"peetprint"`
PeetPrintHash
string
`json:"peetprint_hash"`
ClientRandom
string
`json:"client_random"`
SessionID
string
`json:"session_id"`
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func
TestDialerBasicConnection
(
t
*
testing
.
T
)
{
if
testing
.
Short
()
{
t
.
Skip
(
"skipping network test in short mode"
)
}
// Create a dialer with default profile
profile
:=
&
Profile
{
Name
:
"Test Profile"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
// Create HTTP client with custom TLS dialer
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Make a request to a known HTTPS endpoint
resp
,
err
:=
client
.
Get
(
"https://www.google.com"
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to connect: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
t
.
Errorf
(
"expected status 200, got %d"
,
resp
.
StatusCode
)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func
TestJA3Fingerprint
(
t
*
testing
.
T
)
{
// Skip if network is unavailable or if running in short mode
if
testing
.
Short
()
{
t
.
Skip
(
"skipping integration test in short mode"
)
}
profile
:=
&
Profile
{
Name
:
"Claude CLI Test"
,
EnableGREASE
:
false
,
}
dialer
:=
NewDialer
(
profile
,
nil
)
client
:=
&
http
.
Client
{
Transport
:
&
http
.
Transport
{
DialTLSContext
:
dialer
.
DialTLSContext
,
},
Timeout
:
30
*
time
.
Second
,
}
// Use tls.peet.ws fingerprint detection API
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://tls.peet.ws/api/all"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to create request: %v"
,
err
)
}
req
.
Header
.
Set
(
"User-Agent"
,
"Claude Code/2.0.0 Node.js/20.0.0"
)
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to get fingerprint: %v"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to read response: %v"
,
err
)
}
var
fpResp
FingerprintResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
fpResp
);
err
!=
nil
{
t
.
Logf
(
"Response body: %s"
,
string
(
body
))
t
.
Fatalf
(
"failed to parse fingerprint response: %v"
,
err
)
}
// Log all fingerprint information
t
.
Logf
(
"JA3: %s"
,
fpResp
.
TLS
.
JA3
)
t
.
Logf
(
"JA3 Hash: %s"
,
fpResp
.
TLS
.
JA3Hash
)
t
.
Logf
(
"JA4: %s"
,
fpResp
.
TLS
.
JA4
)
t
.
Logf
(
"PeetPrint: %s"
,
fpResp
.
TLS
.
PeetPrint
)
t
.
Logf
(
"PeetPrint Hash: %s"
,
fpResp
.
TLS
.
PeetPrintHash
)
// Verify JA3 hash matches expected value
expectedJA3Hash
:=
"1a28e69016765d92e3b381168d68922c"
if
fpResp
.
TLS
.
JA3Hash
==
expectedJA3Hash
{
t
.
Logf
(
"✓ JA3 hash matches expected value: %s"
,
expectedJA3Hash
)
}
else
{
t
.
Errorf
(
"✗ JA3 hash mismatch: got %s, expected %s"
,
fpResp
.
TLS
.
JA3Hash
,
expectedJA3Hash
)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix
:=
"_a33745022dd6_1f22a2ca17c4"
if
strings
.
HasSuffix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
{
t
.
Logf
(
"✓ JA4 suffix matches expected value: %s"
,
expectedJA4Suffix
)
}
else
{
t
.
Errorf
(
"✗ JA4 suffix mismatch: got %s, expected suffix %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Suffix
)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix
:=
"t13d5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)"
,
expectedJA4Prefix
)
}
else
{
// Also accept 'i' variant for IP connections
altPrefix
:=
"t13i5911h1"
if
strings
.
HasPrefix
(
fpResp
.
TLS
.
JA4
,
altPrefix
)
{
t
.
Logf
(
"✓ JA4 prefix matches (IP variant): %s"
,
altPrefix
)
}
else
{
t
.
Errorf
(
"✗ JA4 prefix mismatch: got %s, expected %s or %s"
,
fpResp
.
TLS
.
JA4
,
expectedJA4Prefix
,
altPrefix
)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
"4866-4867-4865"
)
{
t
.
Logf
(
"✓ JA3 contains expected TLS 1.3 cipher suites"
)
}
else
{
t
.
Logf
(
"Warning: JA3 does not contain expected TLS 1.3 cipher suites"
)
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions
:=
"0-11-10-35-16-22-23-13-43-45-51"
if
strings
.
Contains
(
fpResp
.
TLS
.
JA3
,
expectedExtensions
)
{
t
.
Logf
(
"✓ JA3 contains expected extension list: %s"
,
expectedExtensions
)
}
else
{
t
.
Logf
(
"Warning: JA3 extension list may differ"
)
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func
TestDialerWithProfile
(
t
*
testing
.
T
)
{
// Create two dialers with different profiles
profile1
:=
&
Profile
{
Name
:
"Profile 1 - No GREASE"
,
EnableGREASE
:
false
,
}
profile2
:=
&
Profile
{
Name
:
"Profile 2 - With GREASE"
,
EnableGREASE
:
true
,
}
dialer1
:=
NewDialer
(
profile1
,
nil
)
dialer2
:=
NewDialer
(
profile2
,
nil
)
// Build specs and compare
// Note: We can't directly compare JA3 without making network requests
// but we can verify the specs are different
spec1
:=
dialer1
.
buildClientHelloSpec
()
spec2
:=
dialer2
.
buildClientHelloSpec
()
// Profile with GREASE should have more extensions
if
len
(
spec2
.
Extensions
)
<=
len
(
spec1
.
Extensions
)
{
t
.
Error
(
"expected GREASE profile to have more extensions"
)
}
}
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func
TestHTTPProxyDialerBasic
(
t
*
testing
.
T
)
{
profile
:=
&
Profile
{
Name
:
"Test Profile"
,
EnableGREASE
:
false
,
}
// Test that dialer is created without panic
proxyURL
:=
mustParseURL
(
"http://proxy.example.com:8080"
)
dialer
:=
NewHTTPProxyDialer
(
profile
,
proxyURL
)
if
dialer
==
nil
{
t
.
Fatal
(
"expected dialer to be created"
)
}
if
dialer
.
profile
!=
profile
{
t
.
Error
(
"expected profile to be set"
)
}
if
dialer
.
proxyURL
!=
proxyURL
{
t
.
Error
(
"expected proxyURL to be set"
)
}
}
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func
TestSOCKS5ProxyDialerBasic
(
t
*
testing
.
T
)
{
profile
:=
&
Profile
{
Name
:
"Test Profile"
,
EnableGREASE
:
false
,
}
// Test that dialer is created without panic
proxyURL
:=
mustParseURL
(
"socks5://proxy.example.com:1080"
)
dialer
:=
NewSOCKS5ProxyDialer
(
profile
,
proxyURL
)
if
dialer
==
nil
{
t
.
Fatal
(
"expected dialer to be created"
)
}
if
dialer
.
profile
!=
profile
{
t
.
Error
(
"expected profile to be set"
)
}
if
dialer
.
proxyURL
!=
proxyURL
{
t
.
Error
(
"expected proxyURL to be set"
)
}
}
// TestBuildClientHelloSpec tests ClientHello spec construction.
func
TestBuildClientHelloSpec
(
t
*
testing
.
T
)
{
// Test with nil profile (should use defaults)
spec
:=
buildClientHelloSpecFromProfile
(
nil
)
if
len
(
spec
.
CipherSuites
)
==
0
{
t
.
Error
(
"expected cipher suites to be set"
)
}
if
len
(
spec
.
Extensions
)
==
0
{
t
.
Error
(
"expected extensions to be set"
)
}
// Verify default cipher suites are used
if
len
(
spec
.
CipherSuites
)
!=
len
(
defaultCipherSuites
)
{
t
.
Errorf
(
"expected %d cipher suites, got %d"
,
len
(
defaultCipherSuites
),
len
(
spec
.
CipherSuites
))
}
// Test with custom profile
customProfile
:=
&
Profile
{
Name
:
"Custom"
,
EnableGREASE
:
false
,
CipherSuites
:
[]
uint16
{
0x1301
,
0x1302
},
}
spec
=
buildClientHelloSpecFromProfile
(
customProfile
)
if
len
(
spec
.
CipherSuites
)
!=
2
{
t
.
Errorf
(
"expected 2 cipher suites, got %d"
,
len
(
spec
.
CipherSuites
))
}
}
// TestToUTLSCurves tests curve ID conversion.
func
TestToUTLSCurves
(
t
*
testing
.
T
)
{
input
:=
[]
uint16
{
0x001d
,
0x0017
,
0x0018
}
result
:=
toUTLSCurves
(
input
)
if
len
(
result
)
!=
len
(
input
)
{
t
.
Errorf
(
"expected %d curves, got %d"
,
len
(
input
),
len
(
result
))
}
for
i
,
curve
:=
range
result
{
if
uint16
(
curve
)
!=
input
[
i
]
{
t
.
Errorf
(
"curve %d: expected 0x%04x, got 0x%04x"
,
i
,
input
[
i
],
uint16
(
curve
))
}
}
}
// Helper function to parse URL without error handling.
func
mustParseURL
(
rawURL
string
)
*
url
.
URL
{
u
,
err
:=
url
.
Parse
(
rawURL
)
if
err
!=
nil
{
panic
(
err
)
}
return
u
}
backend/internal/pkg/tlsfingerprint/registry.go
0 → 100644
View file @
292f25f9
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
package
tlsfingerprint
import
(
"log/slog"
"sort"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// DefaultProfileName is the name of the built-in Claude CLI profile.
const
DefaultProfileName
=
"claude_cli_v2"
// Registry manages TLS fingerprint profiles.
// It holds a collection of profiles that can be used for TLS fingerprint simulation.
// Profiles are selected based on account ID using modulo operation.
type
Registry
struct
{
mu
sync
.
RWMutex
profiles
map
[
string
]
*
Profile
profileNames
[]
string
// Sorted list of profile names for deterministic selection
}
// NewRegistry creates a new TLS fingerprint profile registry.
// It initializes with the built-in default profile.
func
NewRegistry
()
*
Registry
{
r
:=
&
Registry
{
profiles
:
make
(
map
[
string
]
*
Profile
),
profileNames
:
make
([]
string
,
0
),
}
// Register the built-in default profile
r
.
registerBuiltinProfile
()
return
r
}
// NewRegistryFromConfig creates a new registry and loads profiles from config.
// If the config has custom profiles defined, they will be merged with the built-in default.
func
NewRegistryFromConfig
(
cfg
*
config
.
TLSFingerprintConfig
)
*
Registry
{
r
:=
NewRegistry
()
if
cfg
==
nil
||
!
cfg
.
Enabled
{
slog
.
Debug
(
"tls_registry_disabled"
,
"reason"
,
"disabled or no config"
)
return
r
}
// Load custom profiles from config
for
name
,
profileCfg
:=
range
cfg
.
Profiles
{
profile
:=
&
Profile
{
Name
:
profileCfg
.
Name
,
EnableGREASE
:
profileCfg
.
EnableGREASE
,
CipherSuites
:
profileCfg
.
CipherSuites
,
Curves
:
profileCfg
.
Curves
,
PointFormats
:
profileCfg
.
PointFormats
,
}
// If the profile has empty values, they will use defaults in dialer
r
.
RegisterProfile
(
name
,
profile
)
slog
.
Debug
(
"tls_registry_loaded_profile"
,
"key"
,
name
,
"name"
,
profileCfg
.
Name
)
}
slog
.
Debug
(
"tls_registry_initialized"
,
"profile_count"
,
len
(
r
.
profileNames
),
"profiles"
,
r
.
profileNames
)
return
r
}
// registerBuiltinProfile adds the default Claude CLI profile to the registry.
func
(
r
*
Registry
)
registerBuiltinProfile
()
{
defaultProfile
:=
&
Profile
{
Name
:
"Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"
,
EnableGREASE
:
false
,
// Node.js does not use GREASE
// Empty slices will cause dialer to use built-in defaults
CipherSuites
:
nil
,
Curves
:
nil
,
PointFormats
:
nil
,
}
r
.
RegisterProfile
(
DefaultProfileName
,
defaultProfile
)
}
// RegisterProfile adds or updates a profile in the registry.
func
(
r
*
Registry
)
RegisterProfile
(
name
string
,
profile
*
Profile
)
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
// Check if this is a new profile
_
,
exists
:=
r
.
profiles
[
name
]
r
.
profiles
[
name
]
=
profile
if
!
exists
{
r
.
profileNames
=
append
(
r
.
profileNames
,
name
)
// Keep names sorted for deterministic selection
sort
.
Strings
(
r
.
profileNames
)
}
}
// GetProfile returns a profile by name.
// Returns nil if the profile does not exist.
func
(
r
*
Registry
)
GetProfile
(
name
string
)
*
Profile
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
return
r
.
profiles
[
name
]
}
// GetDefaultProfile returns the built-in default profile.
func
(
r
*
Registry
)
GetDefaultProfile
()
*
Profile
{
return
r
.
GetProfile
(
DefaultProfileName
)
}
// GetProfileByAccountID returns a profile for the given account ID.
// The profile is selected using: profileNames[accountID % len(profiles)]
// This ensures deterministic profile assignment for each account.
func
(
r
*
Registry
)
GetProfileByAccountID
(
accountID
int64
)
*
Profile
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
if
len
(
r
.
profileNames
)
==
0
{
return
nil
}
// Use modulo to select profile index
// Use absolute value to handle negative IDs (though unlikely)
idx
:=
accountID
if
idx
<
0
{
idx
=
-
idx
}
selectedIndex
:=
int
(
idx
%
int64
(
len
(
r
.
profileNames
)))
selectedName
:=
r
.
profileNames
[
selectedIndex
]
return
r
.
profiles
[
selectedName
]
}
// ProfileCount returns the number of registered profiles.
func
(
r
*
Registry
)
ProfileCount
()
int
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
return
len
(
r
.
profiles
)
}
// ProfileNames returns a sorted list of all registered profile names.
func
(
r
*
Registry
)
ProfileNames
()
[]
string
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
// Return a copy to prevent modification
names
:=
make
([]
string
,
len
(
r
.
profileNames
))
copy
(
names
,
r
.
profileNames
)
return
names
}
// Global registry instance for convenience
var
globalRegistry
*
Registry
var
globalRegistryOnce
sync
.
Once
// GlobalRegistry returns the global TLS fingerprint registry.
// The registry is lazily initialized with the default profile.
func
GlobalRegistry
()
*
Registry
{
globalRegistryOnce
.
Do
(
func
()
{
globalRegistry
=
NewRegistry
()
})
return
globalRegistry
}
// InitGlobalRegistry initializes the global registry with configuration.
// This should be called during application startup.
// It is safe to call multiple times; subsequent calls will update the registry.
func
InitGlobalRegistry
(
cfg
*
config
.
TLSFingerprintConfig
)
*
Registry
{
globalRegistryOnce
.
Do
(
func
()
{
globalRegistry
=
NewRegistryFromConfig
(
cfg
)
})
return
globalRegistry
}
backend/internal/pkg/tlsfingerprint/registry_test.go
0 → 100644
View file @
292f25f9
package
tlsfingerprint
import
(
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func
TestNewRegistry
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// Should have exactly one profile (the default)
if
r
.
ProfileCount
()
!=
1
{
t
.
Errorf
(
"expected 1 profile, got %d"
,
r
.
ProfileCount
())
}
// Should have the default profile
profile
:=
r
.
GetDefaultProfile
()
if
profile
==
nil
{
t
.
Error
(
"expected default profile to exist"
)
}
// Default profile name should be in the list
names
:=
r
.
ProfileNames
()
if
len
(
names
)
!=
1
||
names
[
0
]
!=
DefaultProfileName
{
t
.
Errorf
(
"expected profile names to be [%s], got %v"
,
DefaultProfileName
,
names
)
}
}
func
TestRegisterProfile
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// Register a new profile
customProfile
:=
&
Profile
{
Name
:
"Custom Profile"
,
EnableGREASE
:
true
,
}
r
.
RegisterProfile
(
"custom"
,
customProfile
)
// Should now have 2 profiles
if
r
.
ProfileCount
()
!=
2
{
t
.
Errorf
(
"expected 2 profiles, got %d"
,
r
.
ProfileCount
())
}
// Should be able to retrieve the custom profile
retrieved
:=
r
.
GetProfile
(
"custom"
)
if
retrieved
==
nil
{
t
.
Fatal
(
"expected custom profile to exist"
)
}
if
retrieved
.
Name
!=
"Custom Profile"
{
t
.
Errorf
(
"expected profile name 'Custom Profile', got '%s'"
,
retrieved
.
Name
)
}
if
!
retrieved
.
EnableGREASE
{
t
.
Error
(
"expected EnableGREASE to be true"
)
}
}
func
TestGetProfile
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// Get existing profile
profile
:=
r
.
GetProfile
(
DefaultProfileName
)
if
profile
==
nil
{
t
.
Error
(
"expected default profile to exist"
)
}
// Get non-existing profile
nonExistent
:=
r
.
GetProfile
(
"nonexistent"
)
if
nonExistent
!=
nil
{
t
.
Error
(
"expected nil for non-existent profile"
)
}
}
func
TestGetProfileByAccountID
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// With only default profile, all account IDs should return the same profile
for
i
:=
int64
(
0
);
i
<
10
;
i
++
{
profile
:=
r
.
GetProfileByAccountID
(
i
)
if
profile
==
nil
{
t
.
Errorf
(
"expected profile for account %d, got nil"
,
i
)
}
}
// Add more profiles
r
.
RegisterProfile
(
"profile_a"
,
&
Profile
{
Name
:
"Profile A"
})
r
.
RegisterProfile
(
"profile_b"
,
&
Profile
{
Name
:
"Profile B"
})
// Now we have 3 profiles: claude_cli_v2, profile_a, profile_b
// Names are sorted, so order is: claude_cli_v2, profile_a, profile_b
expectedOrder
:=
[]
string
{
DefaultProfileName
,
"profile_a"
,
"profile_b"
}
names
:=
r
.
ProfileNames
()
for
i
,
name
:=
range
expectedOrder
{
if
names
[
i
]
!=
name
{
t
.
Errorf
(
"expected name at index %d to be %s, got %s"
,
i
,
name
,
names
[
i
])
}
}
// Test modulo selection
// Account ID 0 % 3 = 0 -> claude_cli_v2
// Account ID 1 % 3 = 1 -> profile_a
// Account ID 2 % 3 = 2 -> profile_b
// Account ID 3 % 3 = 0 -> claude_cli_v2
testCases
:=
[]
struct
{
accountID
int64
expectedName
string
}{
{
0
,
"Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"
},
{
1
,
"Profile A"
},
{
2
,
"Profile B"
},
{
3
,
"Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"
},
{
4
,
"Profile A"
},
{
5
,
"Profile B"
},
{
100
,
"Profile A"
},
// 100 % 3 = 1
{
-
1
,
"Profile A"
},
// |-1| % 3 = 1
{
-
3
,
"Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"
},
// |-3| % 3 = 0
}
for
_
,
tc
:=
range
testCases
{
profile
:=
r
.
GetProfileByAccountID
(
tc
.
accountID
)
if
profile
==
nil
{
t
.
Errorf
(
"expected profile for account %d, got nil"
,
tc
.
accountID
)
continue
}
if
profile
.
Name
!=
tc
.
expectedName
{
t
.
Errorf
(
"account %d: expected profile name '%s', got '%s'"
,
tc
.
accountID
,
tc
.
expectedName
,
profile
.
Name
)
}
}
}
func
TestNewRegistryFromConfig
(
t
*
testing
.
T
)
{
// Test with nil config
r
:=
NewRegistryFromConfig
(
nil
)
if
r
.
ProfileCount
()
!=
1
{
t
.
Errorf
(
"expected 1 profile with nil config, got %d"
,
r
.
ProfileCount
())
}
// Test with disabled config
disabledCfg
:=
&
config
.
TLSFingerprintConfig
{
Enabled
:
false
,
}
r
=
NewRegistryFromConfig
(
disabledCfg
)
if
r
.
ProfileCount
()
!=
1
{
t
.
Errorf
(
"expected 1 profile with disabled config, got %d"
,
r
.
ProfileCount
())
}
// Test with enabled config and custom profiles
enabledCfg
:=
&
config
.
TLSFingerprintConfig
{
Enabled
:
true
,
Profiles
:
map
[
string
]
config
.
TLSProfileConfig
{
"custom1"
:
{
Name
:
"Custom Profile 1"
,
EnableGREASE
:
true
,
},
"custom2"
:
{
Name
:
"Custom Profile 2"
,
EnableGREASE
:
false
,
},
},
}
r
=
NewRegistryFromConfig
(
enabledCfg
)
// Should have 3 profiles: default + 2 custom
if
r
.
ProfileCount
()
!=
3
{
t
.
Errorf
(
"expected 3 profiles, got %d"
,
r
.
ProfileCount
())
}
// Check custom profiles exist
custom1
:=
r
.
GetProfile
(
"custom1"
)
if
custom1
==
nil
||
custom1
.
Name
!=
"Custom Profile 1"
{
t
.
Error
(
"expected custom1 profile to exist with correct name"
)
}
custom2
:=
r
.
GetProfile
(
"custom2"
)
if
custom2
==
nil
||
custom2
.
Name
!=
"Custom Profile 2"
{
t
.
Error
(
"expected custom2 profile to exist with correct name"
)
}
}
func
TestProfileNames
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// Add profiles in non-alphabetical order
r
.
RegisterProfile
(
"zebra"
,
&
Profile
{
Name
:
"Zebra"
})
r
.
RegisterProfile
(
"alpha"
,
&
Profile
{
Name
:
"Alpha"
})
r
.
RegisterProfile
(
"beta"
,
&
Profile
{
Name
:
"Beta"
})
names
:=
r
.
ProfileNames
()
// Should be sorted alphabetically
expected
:=
[]
string
{
"alpha"
,
"beta"
,
DefaultProfileName
,
"zebra"
}
if
len
(
names
)
!=
len
(
expected
)
{
t
.
Errorf
(
"expected %d names, got %d"
,
len
(
expected
),
len
(
names
))
}
for
i
,
name
:=
range
expected
{
if
names
[
i
]
!=
name
{
t
.
Errorf
(
"expected name at index %d to be %s, got %s"
,
i
,
name
,
names
[
i
])
}
}
// Test that returned slice is a copy (modifying it shouldn't affect registry)
names
[
0
]
=
"modified"
originalNames
:=
r
.
ProfileNames
()
if
originalNames
[
0
]
==
"modified"
{
t
.
Error
(
"modifying returned slice should not affect registry"
)
}
}
func
TestConcurrentAccess
(
t
*
testing
.
T
)
{
r
:=
NewRegistry
()
// Run concurrent reads and writes
done
:=
make
(
chan
bool
)
// Writers
for
i
:=
0
;
i
<
10
;
i
++
{
go
func
(
id
int
)
{
for
j
:=
0
;
j
<
100
;
j
++
{
r
.
RegisterProfile
(
"concurrent"
+
string
(
rune
(
'0'
+
id
)),
&
Profile
{
Name
:
"Concurrent"
})
}
done
<-
true
}(
i
)
}
// Readers
for
i
:=
0
;
i
<
10
;
i
++
{
go
func
(
id
int
)
{
for
j
:=
0
;
j
<
100
;
j
++
{
_
=
r
.
ProfileCount
()
_
=
r
.
ProfileNames
()
_
=
r
.
GetProfileByAccountID
(
int64
(
id
*
j
))
_
=
r
.
GetProfile
(
DefaultProfileName
)
}
done
<-
true
}(
i
)
}
// Wait for all goroutines
for
i
:=
0
;
i
<
20
;
i
++
{
<-
done
}
// Test should pass without data races (run with -race flag)
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment