Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
377bffe2
Commit
377bffe2
authored
Feb 03, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test
parents
99250ec5
31fe0178
Changes
235
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_quota_scope.go
View file @
377bffe2
...
...
@@ -89,3 +89,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
}
return
&
resetAt
}
var
antigravityAllScopes
=
[]
AntigravityQuotaScope
{
AntigravityQuotaScopeClaude
,
AntigravityQuotaScopeGeminiText
,
AntigravityQuotaScopeGeminiImage
,
}
func
(
a
*
Account
)
GetAntigravityScopeRateLimits
()
map
[
string
]
int64
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAntigravity
{
return
nil
}
now
:=
time
.
Now
()
result
:=
make
(
map
[
string
]
int64
)
for
_
,
scope
:=
range
antigravityAllScopes
{
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
!=
nil
&&
now
.
Before
(
*
resetAt
)
{
remainingSec
:=
int64
(
time
.
Until
(
*
resetAt
)
.
Seconds
())
if
remainingSec
>
0
{
result
[
string
(
scope
)]
=
remainingSec
}
}
}
if
len
(
result
)
==
0
{
return
nil
}
return
result
}
backend/internal/service/antigravity_token_refresher.go
View file @
377bffe2
...
...
@@ -3,6 +3,8 @@ package service
import
(
"context"
"fmt"
"log"
"strings"
"time"
)
...
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if
newProjectID
,
_
:=
newCredentials
[
"project_id"
]
.
(
string
);
newProjectID
==
""
{
if
oldProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
));
oldProjectID
!=
""
{
newCredentials
[
"project_id"
]
=
oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if
tokenInfo
.
ProjectIDMissing
{
return
newCredentials
,
fmt
.
Errorf
(
"missing_project_id: 账户缺少project id,可能无法使用Antigravity"
)
if
tokenInfo
.
ProjectID
!=
""
{
// 有旧的 project_id,本次获取失败,保留旧值
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id"
,
account
.
ID
)
}
else
{
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试"
,
account
.
ID
)
}
}
return
newCredentials
,
nil
...
...
backend/internal/service/auth_service.go
View file @
377bffe2
...
...
@@ -19,17 +19,19 @@ import (
)
var
(
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrInvitationCodeRequired
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_REQUIRED"
,
"invitation code is required"
)
ErrInvitationCodeInvalid
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_INVALID"
,
"invalid or used invitation code"
)
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
...
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
type
AuthService
struct
{
userRepo
UserRepository
redeemRepo
RedeemCodeRepository
cfg
*
config
.
Config
settingService
*
SettingService
emailService
*
EmailService
...
...
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
func
NewAuthService
(
userRepo
UserRepository
,
redeemRepo
RedeemCodeRepository
,
cfg
*
config
.
Config
,
settingService
*
SettingService
,
emailService
*
EmailService
,
...
...
@@ -67,6 +71,7 @@ func NewAuthService(
)
*
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
redeemRepo
:
redeemRepo
,
cfg
:
cfg
,
settingService
:
settingService
,
emailService
:
emailService
,
...
...
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
)
}
// RegisterWithVerification 用户注册(支持邮件验证
和
优惠码),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
// RegisterWithVerification 用户注册(支持邮件验证
、
优惠码
和邀请码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
...
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrEmailReserved
}
// 检查是否需要邀请码
var
invitationRedeemCode
*
RedeemCode
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsInvitationCodeEnabled
(
ctx
)
{
if
invitationCode
==
""
{
return
""
,
nil
,
ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
// 检查类型和状态
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
log
.
Printf
(
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
invitationRedeemCode
=
redeemCode
}
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
...
...
@@ -153,6 +178,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
}
// 标记邀请码为已使用(如果使用了邀请码)
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
// 邀请码标记失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
@@ -580,3 +613,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return
s
.
GenerateToken
(
user
)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func
(
s
*
AuthService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
false
}
// Must have email verification enabled and SMTP configured
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
return
s
.
settingService
.
IsPasswordResetEnabled
(
ctx
)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func
(
s
*
AuthService
)
preparePasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
(
string
,
string
,
bool
)
{
// Check if user exists (but don't reveal this to the caller)
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// Security: Log but don't reveal that user doesn't exist
log
.
Printf
(
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
return
""
,
""
,
false
}
log
.
Printf
(
"[Auth] Database error checking email for password reset: %v"
,
err
)
return
""
,
""
,
false
}
// Check if user is active
if
!
user
.
IsActive
()
{
log
.
Printf
(
"[Auth] Password reset requested for inactive user: %s"
,
email
)
return
""
,
""
,
false
}
// Get site name
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
// Build reset URL base
resetURL
:=
fmt
.
Sprintf
(
"%s/reset-password"
,
strings
.
TrimSuffix
(
frontendBaseURL
,
"/"
))
return
siteName
,
resetURL
,
true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email sent to: %s"
,
email
)
return
nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordResetAsync
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailQueueService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email enqueued for: %s"
,
email
)
return
nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func
(
s
*
AuthService
)
ResetPassword
(
ctx
context
.
Context
,
email
,
token
,
newPassword
string
)
error
{
// Check if password reset is enabled
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if
err
:=
s
.
emailService
.
ConsumePasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
ErrInvalidResetToken
// Token was valid but user was deleted
}
log
.
Printf
(
"[Auth] Database error getting user for password reset: %v"
,
err
)
return
ErrServiceUnavailable
}
// Check if user is active
if
!
user
.
IsActive
()
{
return
ErrUserNotActive
}
// Hash new password
hashedPassword
,
err
:=
s
.
HashPassword
(
newPassword
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// Update password and increment TokenVersion
user
.
PasswordHash
=
hashedPassword
user
.
TokenVersion
++
// Invalidate all existing tokens
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
return
ErrServiceUnavailable
}
log
.
Printf
(
"[Auth] Password reset successful for user: %s"
,
email
)
return
nil
}
backend/internal/service/auth_service_register_test.go
View file @
377bffe2
...
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
{
return
false
}
func
(
s
*
emailCacheStub
)
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
...
...
@@ -95,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return
NewAuthService
(
repo
,
nil
,
// redeemRepo
cfg
,
settingService
,
emailService
,
...
...
@@ -132,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
...
...
@@ -144,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
...
...
@@ -158,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
...
...
backend/internal/service/billing_service.go
View file @
377bffe2
...
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func
(
s
*
BillingService
)
CalculateCostWithLongContext
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
threshold
int
,
extraMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
// 未启用长上下文计费,直接走正常计费
if
threshold
<=
0
||
extraMultiplier
<=
1
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 计算总输入 token(缓存读取 + 新输入)
total
:=
tokens
.
CacheReadTokens
+
tokens
.
InputTokens
if
total
<=
threshold
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 拆分成范围内和范围外
var
inRangeCacheTokens
,
inRangeInputTokens
int
var
outRangeCacheTokens
,
outRangeInputTokens
int
if
tokens
.
CacheReadTokens
>=
threshold
{
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens
=
threshold
inRangeInputTokens
=
0
outRangeCacheTokens
=
tokens
.
CacheReadTokens
-
threshold
outRangeInputTokens
=
tokens
.
InputTokens
}
else
{
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens
=
tokens
.
CacheReadTokens
inRangeInputTokens
=
threshold
-
tokens
.
CacheReadTokens
outRangeCacheTokens
=
0
outRangeInputTokens
=
tokens
.
InputTokens
-
inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens
:=
UsageTokens
{
InputTokens
:
inRangeInputTokens
,
OutputTokens
:
tokens
.
OutputTokens
,
// 输出只算一次
CacheCreationTokens
:
tokens
.
CacheCreationTokens
,
CacheReadTokens
:
inRangeCacheTokens
,
}
inRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
inRangeTokens
,
rateMultiplier
)
if
err
!=
nil
{
return
nil
,
err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens
:=
UsageTokens
{
InputTokens
:
outRangeInputTokens
,
CacheReadTokens
:
outRangeCacheTokens
,
}
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
if
err
!=
nil
{
return
inRangeCost
,
nil
// 出错时返回范围内成本
}
// 合并成本
return
&
CostBreakdown
{
InputCost
:
inRangeCost
.
InputCost
+
outRangeCost
.
InputCost
,
OutputCost
:
inRangeCost
.
OutputCost
,
CacheCreationCost
:
inRangeCost
.
CacheCreationCost
,
CacheReadCost
:
inRangeCost
.
CacheReadCost
+
outRangeCost
.
CacheReadCost
,
TotalCost
:
inRangeCost
.
TotalCost
+
outRangeCost
.
TotalCost
,
ActualCost
:
inRangeCost
.
ActualCost
+
outRangeCost
.
ActualCost
,
},
nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
models
:=
make
([]
string
,
0
)
...
...
backend/internal/service/domain_constants.go
View file @
377bffe2
package
service
import
"github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants
const
(
StatusActive
=
"a
ctive
"
StatusDisabled
=
"
disabled
"
StatusError
=
"e
rror
"
StatusUnused
=
"u
nused
"
StatusUsed
=
"u
sed
"
StatusExpired
=
"e
xpired
"
StatusActive
=
domain
.
StatusA
ctive
StatusDisabled
=
d
omain
.
StatusD
isabled
StatusError
=
domain
.
StatusE
rror
StatusUnused
=
domain
.
StatusU
nused
StatusUsed
=
domain
.
StatusU
sed
StatusExpired
=
domain
.
StatusE
xpired
)
// Role constants
const
(
RoleAdmin
=
"a
dmin
"
RoleUser
=
"u
ser
"
RoleAdmin
=
domain
.
RoleA
dmin
RoleUser
=
domain
.
RoleU
ser
)
// Platform constants
const
(
PlatformAnthropic
=
"a
nthropic
"
PlatformOpenAI
=
"openai"
PlatformGemini
=
"g
emini
"
PlatformAntigravity
=
"a
ntigravity
"
PlatformSora
=
"s
ora
"
PlatformAnthropic
=
domain
.
PlatformA
nthropic
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformGemini
=
domain
.
PlatformG
emini
PlatformAntigravity
=
domain
.
PlatformA
ntigravity
PlatformSora
=
domain
.
PlatformS
ora
)
// Account type constants
const
(
AccountTypeOAuth
=
"oa
uth
"
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"s
etup
-t
oken
"
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
"apikey"
// API Key类型账号
AccountTypeOAuth
=
domain
.
AccountTypeOA
uth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeS
etup
T
oken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
)
// Redeem type constants
const
(
RedeemTypeBalance
=
"balance"
RedeemTypeConcurrency
=
"concurrency"
RedeemTypeSubscription
=
"subscription"
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeInvitation
=
domain
.
RedeemTypeInvitation
)
// PromoCode status constants
const
(
PromoCodeStatusActive
=
"a
ctive
"
PromoCodeStatusDisabled
=
"
disabled
"
PromoCodeStatusActive
=
domain
.
PromoCodeStatusA
ctive
PromoCodeStatusDisabled
=
d
omain
.
PromoCodeStatusD
isabled
)
// Admin adjustment type constants
const
(
AdjustmentTypeAdminBalance
=
"a
dmin
_b
alance
"
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"a
dmin
_c
oncurrency
"
// 管理员调整并发数
AdjustmentTypeAdminBalance
=
domain
.
AdjustmentTypeA
dmin
B
alance
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
domain
.
AdjustmentTypeA
dmin
C
oncurrency
// 管理员调整并发数
)
// Group subscription type constants
const
(
SubscriptionTypeStandard
=
"s
tandard
"
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"s
ubscription
"
// 订阅模式(按限额控制)
SubscriptionTypeStandard
=
domain
.
SubscriptionTypeS
tandard
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
domain
.
SubscriptionTypeS
ubscription
// 订阅模式(按限额控制)
)
// Subscription status constants
const
(
SubscriptionStatusActive
=
"a
ctive
"
SubscriptionStatusExpired
=
"e
xpired
"
SubscriptionStatusSuspended
=
"s
uspended
"
SubscriptionStatusActive
=
domain
.
SubscriptionStatusA
ctive
SubscriptionStatusExpired
=
domain
.
SubscriptionStatusE
xpired
SubscriptionStatusSuspended
=
domain
.
SubscriptionStatusS
uspended
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
...
...
@@ -70,9 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const
(
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled
=
"invitation_code_enabled"
// 是否启用邀请码注册
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
...
@@ -88,6 +93,9 @@ const (
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled
=
"totp_enabled"
// 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
...
...
@@ -95,14 +103,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// “购买订阅”页面 URL(作为 iframe src)
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
backend/internal/service/email_queue_service.go
View file @
377bffe2
...
...
@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const
(
TaskTypeVerifyCode
=
"verify_code"
TaskTypePasswordReset
=
"password_reset"
)
// EmailTask 邮件发送任务
type
EmailTask
struct
{
Email
string
SiteName
string
TaskType
string
// "verify_code"
TaskType
string
// "verify_code" or "password_reset"
ResetURL
string
// Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
...
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer
cancel
()
switch
task
.
TaskType
{
case
"v
erify
_c
ode
"
:
case
TaskTypeV
erify
C
ode
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
case
TaskTypePasswordReset
:
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
}
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
...
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
"v
erify
_c
ode
"
,
TaskType
:
TaskTypeV
erify
C
ode
,
}
select
{
...
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func
(
s
*
EmailQueueService
)
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
string
)
error
{
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
TaskTypePasswordReset
,
ResetURL
:
resetURL
,
}
select
{
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
return
nil
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
}
}
// Stop 停止队列服务
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
...
...
backend/internal/service/email_service.go
View file @
377bffe2
...
...
@@ -3,11 +3,14 @@ package service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
...
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
// EmailCache defines cache operations for email service
...
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
// VerificationCodeData represents verification code data
...
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
// SMTPConfig SMTP配置
...
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if
data
.
Code
!=
code
{
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
377bffe2
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
377bffe2
...
...
@@ -269,6 +269,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
...
...
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
377bffe2
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
377bffe2
...
...
@@ -2,6 +2,7 @@ package service
import
(
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
...
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
name
string
body
string
...
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
{
name
:
"string system equals Claude Code prompt"
,
...
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
{
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
377bffe2
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
377bffe2
...
...
@@ -20,12 +20,14 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
...
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
return
""
...
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
return
sessionHash
[
:
8
]
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
toolPrefixRe
=
regexp
.
MustCompile
(
`(?i)^(?:oc_|mcp_)`
)
toolNameBoundaryRe
=
regexp
.
MustCompile
(
`[^a-zA-Z0-9]+`
)
toolNameCamelRe
=
regexp
.
MustCompile
(
`([a-z0-9])([A-Z])`
)
toolNameFieldRe
=
regexp
.
MustCompile
(
`"name"\s*:\s*"([^"]+)"`
)
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
"read"
:
"Read"
,
"edit"
:
"Edit"
,
"write"
:
"Write"
,
"task"
:
"Task"
,
"glob"
:
"Glob"
,
"grep"
:
"Grep"
,
"webfetch"
:
"WebFetch"
,
"websearch"
:
"WebSearch"
,
"todowrite"
:
"TodoWrite"
,
"question"
:
"AskUserQuestion"
,
}
openCodeToolOverrides
=
map
[
string
]
string
{
"Bash"
:
"bash"
,
"Read"
:
"read"
,
"Edit"
:
"edit"
,
"Write"
:
"write"
,
"Task"
:
"task"
,
"Glob"
:
"glob"
,
"Grep"
:
"grep"
,
"WebFetch"
:
"webfetch"
,
"WebSearch"
:
"websearch"
,
"TodoWrite"
:
"todowrite"
,
"AskUserQuestion"
:
"question"
,
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...
...
@@ -309,6 +489,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func
(
s
*
GatewayService
)
GetCachedSessionAccountID
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
sessionHash
==
""
||
s
.
cache
==
nil
{
return
0
,
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
!=
nil
{
return
0
,
err
}
return
accountID
,
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
return
""
...
...
@@ -409,6 +602,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return
newBody
}
type
claudeOAuthNormalizeOptions
struct
{
injectMetadata
bool
metadataUserID
string
stripSystemCacheControl
bool
}
func
stripToolPrefix
(
value
string
)
string
{
if
value
==
""
{
return
value
}
return
toolPrefixRe
.
ReplaceAllString
(
value
,
""
)
}
func
toPascalCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
normalized
:=
toolNameBoundaryRe
.
ReplaceAllString
(
value
,
" "
)
tokens
:=
make
([]
string
,
0
)
for
_
,
token
:=
range
strings
.
Fields
(
normalized
)
{
expanded
:=
toolNameCamelRe
.
ReplaceAllString
(
token
,
"$1 $2"
)
parts
:=
strings
.
Fields
(
expanded
)
if
len
(
parts
)
>
0
{
tokens
=
append
(
tokens
,
parts
...
)
}
}
if
len
(
tokens
)
==
0
{
return
value
}
var
builder
strings
.
Builder
for
_
,
token
:=
range
tokens
{
lower
:=
strings
.
ToLower
(
token
)
if
lower
==
""
{
continue
}
runes
:=
[]
rune
(
lower
)
runes
[
0
]
=
unicode
.
ToUpper
(
runes
[
0
])
_
,
_
=
builder
.
WriteString
(
string
(
runes
))
}
return
builder
.
String
()
}
func
toSnakeCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
output
:=
toolNameCamelRe
.
ReplaceAllString
(
value
,
"$1_$2"
)
output
=
toolNameBoundaryRe
.
ReplaceAllString
(
output
,
"_"
)
output
=
strings
.
Trim
(
output
,
"_"
)
return
strings
.
ToLower
(
output
)
}
func
normalizeToolNameForClaude
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
mapped
,
ok
:=
claudeToolNameOverrides
[
strings
.
ToLower
(
stripped
)]
if
!
ok
{
mapped
=
toPascalCase
(
stripped
)
}
if
mapped
!=
""
&&
cache
!=
nil
&&
mapped
!=
stripped
{
cache
[
mapped
]
=
stripped
}
if
mapped
==
""
{
return
stripped
}
return
mapped
}
func
normalizeToolNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
stripped
];
ok
{
return
mapped
}
}
if
mapped
,
ok
:=
openCodeToolOverrides
[
stripped
];
ok
{
return
mapped
}
return
toSnakeCase
(
stripped
)
}
func
normalizeParamNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
name
];
ok
{
return
mapped
}
}
return
name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
func
sanitizeToolDescription
(
description
string
)
string
{
if
description
==
""
{
return
description
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
schema
,
ok
:=
inputSchema
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
properties
,
ok
:=
schema
[
"properties"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
newProperties
:=
make
(
map
[
string
]
any
,
len
(
properties
))
for
key
,
value
:=
range
properties
{
snakeKey
:=
toSnakeCase
(
key
)
newProperties
[
snakeKey
]
=
value
if
snakeKey
!=
key
&&
cache
!=
nil
{
cache
[
snakeKey
]
=
key
}
}
schema
[
"properties"
]
=
newProperties
if
required
,
ok
:=
schema
[
"required"
]
.
([]
any
);
ok
{
newRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
item
:=
range
required
{
name
,
ok
:=
item
.
(
string
)
if
!
ok
{
newRequired
=
append
(
newRequired
,
item
)
continue
}
snakeName
:=
toSnakeCase
(
name
)
newRequired
=
append
(
newRequired
,
snakeName
)
if
snakeName
!=
name
&&
cache
!=
nil
{
cache
[
snakeName
]
=
name
}
}
schema
[
"required"
]
=
newRequired
}
}
func
stripCacheControlFromSystemBlocks
(
system
any
)
bool
{
blocks
,
ok
:=
system
.
([]
any
)
if
!
ok
{
return
false
}
changed
:=
false
for
_
,
item
:=
range
blocks
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
_
,
exists
:=
block
[
"cache_control"
];
!
exists
{
continue
}
delete
(
block
,
"cache_control"
)
changed
=
true
}
return
changed
}
func
normalizeClaudeOAuthRequestBody
(
body
[]
byte
,
modelID
string
,
opts
claudeOAuthNormalizeOptions
)
([]
byte
,
string
,
map
[
string
]
string
)
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
case
[]
any
:
for
_
,
item
:=
range
v
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
block
[
"type"
]
.
(
string
);
blockType
!=
"text"
{
continue
}
text
,
ok
:=
block
[
"text"
]
.
(
string
)
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
}
}
}
if
rawModel
,
ok
:=
req
[
"model"
]
.
(
string
);
ok
{
normalized
:=
claude
.
NormalizeModelID
(
rawModel
)
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
}
}
if
rawTools
,
exists
:=
req
[
"tools"
];
exists
{
switch
tools
:=
rawTools
.
(
type
)
{
case
[]
any
:
for
idx
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
tools
[
idx
]
=
toolMap
}
req
[
"tools"
]
=
tools
case
map
[
string
]
any
:
normalizedTools
:=
make
(
map
[
string
]
any
,
len
(
tools
))
for
name
,
value
:=
range
tools
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
==
""
{
normalized
=
name
}
if
toolMap
,
ok
:=
value
.
(
map
[
string
]
any
);
ok
{
toolMap
[
"name"
]
=
normalized
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
normalizedTools
[
normalized
]
=
toolMap
continue
}
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
continue
}
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
);
blockType
!=
"tool_use"
{
continue
}
if
name
,
ok
:=
blockMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
}
}
}
}
}
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
}
}
if
opts
.
injectMetadata
&&
opts
.
metadataUserID
!=
""
{
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
metadata
=
map
[
string
]
any
{}
req
[
"metadata"
]
=
metadata
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
return
newBody
,
modelID
,
toolNameMap
}
func
(
s
*
GatewayService
)
buildOAuthMetadataUserID
(
parsed
*
ParsedRequest
,
account
*
Account
,
fp
*
Fingerprint
)
string
{
if
parsed
==
nil
||
account
==
nil
{
return
""
}
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
sessionID
:=
uuid
.
NewString
()
if
sessionHash
!=
""
{
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
if
seed
==
""
{
return
uuid
.
NewString
()
}
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
...
...
@@ -1884,6 +2465,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
return
true
}
// 其他平台使用账户的模型支持检查
return
account
.
IsModelSupported
(
requestedModel
)
}
...
...
@@ -2008,6 +2593,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
}
func
isClaudeCodeRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
parsed
*
ParsedRequest
)
bool
{
if
IsClaudeCodeClient
(
ctx
)
{
return
true
}
if
parsed
==
nil
||
c
==
nil
{
return
false
}
return
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
...
...
@@ -2044,6 +2639,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
...
...
@@ -2051,19 +2650,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
}
...
...
@@ -2267,21 +2883,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
originalModel
:=
reqModel
var
toolNameMap
map
[
string
]
string
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
if
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
&&
fp
!=
nil
{
if
metadataUserID
:=
s
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
);
metadataUserID
!=
""
{
normalizeOpts
.
injectMetadata
=
true
normalizeOpts
.
metadataUserID
=
metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
account
.
IsOAuth
()
&&
!
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
&&
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
body
,
reqModel
,
toolNameMap
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
originalModel
:=
reqModel
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
...
...
@@ -2313,10 +2946,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2394,7 +3026,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -2426,7 +3058,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
...
...
@@ -2651,7 +3283,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
reqStream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
)
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
.
Error
()
==
"have error in stream"
{
return
nil
,
&
UpstreamFailoverError
{
...
...
@@ -2664,7 +3296,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
)
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2681,7 +3313,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
},
nil
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
...
...
@@ -2695,11 +3327,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
...
...
@@ -2730,7 +3367,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -2751,10 +3388,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
reqStream
)
}
// 处理anthropic-beta header(OAuth账号需要
特殊处理
)
// 处理
anthropic-beta header(OAuth
账号需要
包含 oauth beta
)
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
))
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
...
...
@@ -2764,6 +3421,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
@@ -2833,22 +3499,109 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return
claude
.
APIKeyBetaHeader
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
func
applyClaudeOAuthHeaderDefaults
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
if
len
(
b
)
>
maxBytes
{
b
=
b
[
:
maxBytes
]
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
if
req
.
Header
.
Get
(
key
)
==
""
{
req
.
Header
.
Set
(
key
,
value
)
}
}
if
isStream
&&
req
.
Header
.
Get
(
"x-stainless-helper-method"
)
==
""
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
add
:=
func
(
v
string
)
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
}
if
len
(
b
)
>
maxBytes
{
b
=
b
[
:
maxBytes
]
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func
(
s
*
GatewayService
)
isThinkingBlockSignatureError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
if
msg
==
""
{
...
...
@@ -2936,6 +3689,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
...
@@ -3065,6 +3832,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
...
@@ -3117,7 +3897,7 @@ type streamingResult struct {
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
...
@@ -3212,6 +3992,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines
:=
make
([]
string
,
0
,
4
)
var
toolInputBuffers
map
[
int
]
string
if
mimicClaudeCode
{
toolInputBuffers
=
make
(
map
[
int
]
string
)
}
transformToolInputJSON
:=
func
(
raw
string
)
string
{
if
!
mimicClaudeCode
{
return
raw
}
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
var
parsed
any
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
parsed
);
err
!=
nil
{
return
replaceToolNamesInText
(
raw
,
toolNameMap
)
}
rewritten
,
changed
:=
rewriteParamKeysInValue
(
parsed
,
toolNameMap
)
if
changed
{
if
bytes
,
err
:=
json
.
Marshal
(
rewritten
);
err
==
nil
{
return
string
(
bytes
)
}
}
return
raw
}
processSSEEvent
:=
func
(
lines
[]
string
)
([]
string
,
string
,
error
)
{
if
len
(
lines
)
==
0
{
return
nil
,
""
,
nil
}
eventName
:=
""
dataLine
:=
""
for
_
,
line
:=
range
lines
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
{
eventName
=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
))
continue
}
if
dataLine
==
""
&&
sseDataRe
.
MatchString
(
trimmed
)
{
dataLine
=
sseDataRe
.
ReplaceAllString
(
trimmed
,
""
)
}
}
if
eventName
==
"error"
{
return
nil
,
dataLine
,
errors
.
New
(
"have error in stream"
)
}
if
dataLine
==
""
{
return
[]
string
{
strings
.
Join
(
lines
,
"
\n
"
)
+
"
\n\n
"
},
""
,
nil
}
if
dataLine
==
"[DONE]"
{
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
dataLine
+
"
\n\n
"
return
[]
string
{
block
},
dataLine
,
nil
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
dataLine
),
&
event
);
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
eventType
,
_
:=
event
[
"type"
]
.
(
string
)
if
eventName
==
""
{
eventName
=
eventType
}
if
needModelReplace
{
if
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
model
,
ok
:=
msg
[
"model"
]
.
(
string
);
ok
&&
model
==
mappedModel
{
msg
[
"model"
]
=
originalModel
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_delta"
{
if
delta
,
ok
:=
event
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
deltaType
,
_
:=
delta
[
"type"
]
.
(
string
);
deltaType
==
"input_json_delta"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
partial
,
ok
:=
delta
[
"partial_json"
]
.
(
string
);
ok
{
toolInputBuffers
[
index
]
+=
partial
}
}
return
nil
,
dataLine
,
nil
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_stop"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
buffered
:=
toolInputBuffers
[
index
];
buffered
!=
""
{
delete
(
toolInputBuffers
,
index
)
transformed
:=
transformToolInputJSON
(
buffered
)
synthetic
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
any
{
"type"
:
"input_json_delta"
,
"partial_json"
:
transformed
,
},
}
synthBytes
,
synthErr
:=
json
.
Marshal
(
synthetic
)
if
synthErr
==
nil
{
synthBlock
:=
"event: content_block_delta
\n
"
+
"data: "
+
string
(
synthBytes
)
+
"
\n\n
"
rewriteToolNamesInValue
(
event
,
toolNameMap
)
stopBytes
,
stopErr
:=
json
.
Marshal
(
event
)
if
stopErr
==
nil
{
stopBlock
:=
""
if
eventName
!=
""
{
stopBlock
=
"event: "
+
eventName
+
"
\n
"
}
stopBlock
+=
"data: "
+
string
(
stopBytes
)
+
"
\n\n
"
return
[]
string
{
synthBlock
,
stopBlock
},
string
(
stopBytes
),
nil
}
}
}
}
}
if
mimicClaudeCode
{
rewriteToolNamesInValue
(
event
,
toolNameMap
)
}
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
string
(
newData
)
+
"
\n\n
"
return
[]
string
{
block
},
string
(
newData
),
nil
}
for
{
select
{
case
ev
,
ok
:=
<-
events
:
...
...
@@ -3240,43 +4185,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
if
line
==
"event: error"
{
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
errors
.
New
(
"have error in stream"
)
}
trimmed
:=
strings
.
TrimSpace
(
line
)
// Extract data from SSE line (supports both "data: " and "data:" formats)
var
data
string
if
sseDataRe
.
MatchString
(
line
)
{
data
=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
if
trimmed
==
""
{
if
len
(
pendingEventLines
)
==
0
{
continue
}
}
// 写入客户端(统一处理 data 行和非 data 行)
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
else
{
flusher
.
Flush
()
outputBlocks
,
data
,
err
:=
processSSEEvent
(
pendingEventLines
)
pendingEventLines
=
pendingEventLines
[
:
0
]
if
err
!=
nil
{
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
err
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
for
_
,
block
:=
range
outputBlocks
{
if
!
clientDisconnected
{
if
_
,
werr
:=
fmt
.
Fprint
(
w
,
block
);
werr
!=
nil
{
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
break
}
flusher
.
Flush
()
}
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
s
.
parseSSEUsage
(
data
,
usage
)
continue
}
pendingEventLines
=
append
(
pendingEventLines
,
line
)
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
...
@@ -3299,43 +4245,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
if
!
sseDataRe
.
MatchString
(
line
)
{
return
line
}
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
if
data
==
""
||
data
==
"[DONE]"
{
return
line
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
return
line
}
// 只替换 message_start 事件中的 message.model
if
event
[
"type"
]
!=
"message_start"
{
return
line
func
rewriteParamKeysInValue
(
value
any
,
cache
map
[
string
]
string
)
(
any
,
bool
)
{
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
changed
:=
false
rewritten
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
item
:=
range
v
{
newKey
:=
normalizeParamNameForOpenCode
(
key
,
cache
)
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
if
newKey
!=
key
{
changed
=
true
}
rewritten
[
newKey
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
case
[]
any
:
changed
:=
false
rewritten
:=
make
([]
any
,
len
(
v
))
for
idx
,
item
:=
range
v
{
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
rewritten
[
idx
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
default
:
return
value
,
false
}
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
line
func
rewriteToolNamesInValue
(
value
any
,
toolNameMap
map
[
string
]
string
)
bool
{
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
changed
:=
false
if
blockType
,
_
:=
v
[
"type"
]
.
(
string
);
blockType
==
"tool_use"
{
if
name
,
ok
:=
v
[
"name"
]
.
(
string
);
ok
{
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
!=
name
{
v
[
"name"
]
=
mapped
changed
=
true
}
}
if
input
,
ok
:=
v
[
"input"
]
.
(
map
[
string
]
any
);
ok
{
rewrittenInput
,
inputChanged
:=
rewriteParamKeysInValue
(
input
,
toolNameMap
)
if
inputChanged
{
if
m
,
ok
:=
rewrittenInput
.
(
map
[
string
]
any
);
ok
{
v
[
"input"
]
=
m
changed
=
true
}
}
}
}
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
case
[]
any
:
changed
:=
false
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
default
:
return
false
}
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
line
func
replaceToolNamesInText
(
text
string
,
toolNameMap
map
[
string
]
string
)
string
{
if
text
==
""
{
return
text
}
output
:=
toolNameFieldRe
.
ReplaceAllStringFunc
(
text
,
func
(
match
string
)
string
{
submatches
:=
toolNameFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
name
:=
submatches
[
1
]
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
==
name
{
return
match
}
return
strings
.
Replace
(
match
,
name
,
mapped
,
1
)
})
output
=
modelFieldRe
.
ReplaceAllStringFunc
(
output
,
func
(
match
string
)
string
{
submatches
:=
modelFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
model
:=
submatches
[
1
]
mapped
:=
claude
.
DenormalizeModelID
(
model
)
if
mapped
==
model
{
return
match
}
return
strings
.
Replace
(
match
,
model
,
mapped
,
1
)
})
msg
[
"model"
]
=
toModel
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
return
line
for
mapped
,
original
:=
range
toolNameMap
{
if
mapped
==
""
||
original
==
""
||
mapped
==
original
{
continue
}
output
=
strings
.
ReplaceAll
(
output
,
"
\"
"
+
mapped
+
"
\"
:"
,
"
\"
"
+
original
+
"
\"
:"
)
output
=
strings
.
ReplaceAll
(
output
,
"
\\\"
"
+
mapped
+
"
\\\"
:"
,
"
\\\"
"
+
original
+
"
\\\"
:"
)
}
return
"data: "
+
string
(
newData
)
return
output
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
...
...
@@ -3363,23 +4390,25 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
`json:"usage"`
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgDelta
)
==
nil
&&
msgDelta
.
Type
==
"message_delta"
{
// output_tokens 总是从 message_delta 获取
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
if
usage
.
InputTokens
==
0
{
// message_delta 仅覆盖存在且非0的字段
// 避免覆盖 message_start 中已有的值(如 input_tokens)
// Claude API 的 message_delta 通常只包含 output_tokens
if
msgDelta
.
Usage
.
InputTokens
>
0
{
usage
.
InputTokens
=
msgDelta
.
Usage
.
InputTokens
}
if
usage
.
CacheCreationInputTokens
==
0
{
if
msgDelta
.
Usage
.
OutputTokens
>
0
{
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
}
if
msgDelta
.
Usage
.
CacheCreationInputTokens
>
0
{
usage
.
CacheCreationInputTokens
=
msgDelta
.
Usage
.
CacheCreationInputTokens
}
if
u
sage
.
CacheReadInputTokens
==
0
{
if
msgDelta
.
U
sage
.
CacheReadInputTokens
>
0
{
usage
.
CacheReadInputTokens
=
msgDelta
.
Usage
.
CacheReadInputTokens
}
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
...
@@ -3400,6 +4429,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
if
mimicClaudeCode
{
body
=
s
.
replaceToolNamesInResponseBody
(
body
,
toolNameMap
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
...
@@ -3437,6 +4469,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
newBody
}
func
(
s
*
GatewayService
)
replaceToolNamesInResponseBody
(
body
[]
byte
,
toolNameMap
map
[
string
]
string
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
replaced
:=
replaceToolNamesInText
(
string
(
body
),
toolNameMap
)
if
replaced
==
string
(
body
)
{
return
body
}
return
[]
byte
(
replaced
)
}
if
!
rewriteToolNamesInValue
(
resp
,
toolNameMap
)
{
return
body
}
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
...
...
@@ -3613,6 +4667,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return
nil
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type
RecordUsageLongContextInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func
(
s
*
GatewayService
)
RecordUsageWithLongContext
(
ctx
context
.
Context
,
input
*
RecordUsageLongContextInput
)
error
{
result
:=
input
.
Result
apiKey
:=
input
.
APIKey
user
:=
input
.
User
account
:=
input
.
Account
subscription
:=
input
.
Subscription
// 获取费率倍数
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
}
var
cost
*
CostBreakdown
// 根据请求类型选择计费方式
if
result
.
ImageCount
>
0
{
// 图片生成计费
var
groupConfig
*
ImagePriceConfig
if
apiKey
.
Group
!=
nil
{
groupConfig
=
&
ImagePriceConfig
{
Price1K
:
apiKey
.
Group
.
ImagePrice1K
,
Price2K
:
apiKey
.
Group
.
ImagePrice2K
,
Price4K
:
apiKey
.
Group
.
ImagePrice4K
,
}
}
cost
=
s
.
billingService
.
CalculateImageCost
(
result
.
Model
,
result
.
ImageSize
,
result
.
ImageCount
,
groupConfig
,
multiplier
)
}
else
{
// Token 计费(使用长上下文计费方法)
tokens
:=
UsageTokens
{
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
}
var
err
error
cost
,
err
=
s
.
billingService
.
CalculateCostWithLongContext
(
result
.
Model
,
tokens
,
multiplier
,
input
.
LongContextThreshold
,
input
.
LongContextMultiplier
)
if
err
!=
nil
{
log
.
Printf
(
"Calculate cost failed: %v"
,
err
)
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling
:=
subscription
!=
nil
&&
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
billingType
:=
BillingTypeBalance
if
isSubscriptionBilling
{
billingType
=
BillingTypeSubscription
}
// 创建使用日志
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
var
imageSize
*
string
if
result
.
ImageSize
!=
""
{
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
imageSize
,
CreatedAt
:
time
.
Now
(),
}
// 添加 UserAgent
if
input
.
UserAgent
!=
""
{
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
}
if
subscription
!=
nil
{
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
err
!=
nil
{
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
// 根据计费类型执行扣费
if
isSubscriptionBilling
{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if
shouldBill
&&
cost
.
TotalCost
>
0
{
if
err
:=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
log
.
Printf
(
"Increment subscription usage failed: %v"
,
err
)
}
// 异步更新订阅缓存
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
}
}
else
{
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if
shouldBill
&&
cost
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Deduct balance failed: %v"
,
err
)
}
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
}
}
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
...
...
@@ -3624,6 +4834,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
body
,
reqModel
,
_
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
...
...
@@ -3650,7 +4868,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
return
err
...
...
@@ -3683,7 +4901,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -3748,7 +4966,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标 URL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
...
...
@@ -3762,10 +4980,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
...
...
@@ -3789,7 +5012,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -3800,7 +5023,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
...
...
@@ -3813,10 +5036,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
false
)
}
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
}
else
{
beta
:=
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
)
if
!
strings
.
Contains
(
beta
,
claude
.
BetaTokenCounting
)
{
beta
=
beta
+
","
+
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
...
...
@@ -3826,6 +5069,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
377bffe2
...
...
@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay
=
16
*
time
.
Second
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const
geminiDummyThoughtSignature
=
"skip_thought_signature_validator"
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
groupRepo
GroupRepository
...
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
geminiReq
=
ensureGeminiFunctionCallThoughtSignatures
(
geminiReq
)
originalClaudeBody
:=
body
proxyURL
:=
""
...
...
@@ -931,6 +937,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
...
...
@@ -938,6 +951,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream
:
req
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
...
...
@@ -969,6 +984,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body
=
ensureGeminiFunctionCallThoughtSignatures
(
body
)
mappedModel
:=
originalModel
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
...
...
@@ -1371,6 +1390,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage
=
&
ClaudeUsage
{}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
...
...
@@ -1378,6 +1404,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
...
...
@@ -2504,9 +2532,13 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
}
prompt
,
_
:=
asInt
(
usageMeta
[
"promptTokenCount"
])
cand
,
_
:=
asInt
(
usageMeta
[
"candidatesTokenCount"
])
cached
,
_
:=
asInt
(
usageMeta
[
"cachedContentTokenCount"
])
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return
&
ClaudeUsage
{
InputTokens
:
prompt
,
OutputTokens
:
cand
,
InputTokens
:
prompt
-
cached
,
OutputTokens
:
cand
,
CacheReadInputTokens
:
cached
,
}
}
...
...
@@ -2635,6 +2667,58 @@ func nextGeminiDailyResetUnix() *int64 {
return
&
ts
}
func
ensureGeminiFunctionCallThoughtSignatures
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when functionCall is present.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"functionCall"`
))
{
return
body
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
body
}
contentsAny
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contentsAny
)
==
0
{
return
body
}
modified
:=
false
for
_
,
c
:=
range
contentsAny
{
cm
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
partsAny
,
ok
:=
cm
[
"parts"
]
.
([]
any
)
if
!
ok
||
len
(
partsAny
)
==
0
{
continue
}
for
_
,
p
:=
range
partsAny
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
||
pm
==
nil
{
continue
}
if
fc
,
ok
:=
pm
[
"functionCall"
]
.
(
map
[
string
]
any
);
!
ok
||
fc
==
nil
{
continue
}
ts
,
_
:=
pm
[
"thoughtSignature"
]
.
(
string
)
if
strings
.
TrimSpace
(
ts
)
==
""
{
pm
[
"thoughtSignature"
]
=
geminiDummyThoughtSignature
modified
=
true
}
}
}
if
!
modified
{
return
body
}
b
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
body
}
return
b
}
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
...
...
@@ -2834,7 +2918,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
toolUseIDToName
[
id
]
=
name
}
signature
,
_
:=
bm
[
"signature"
]
.
(
string
)
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
signature
=
geminiDummyThoughtSignature
}
parts
=
append
(
parts
,
map
[
string
]
any
{
"thoughtSignature"
:
signature
,
"functionCall"
:
map
[
string
]
any
{
"name"
:
name
,
"args"
:
bm
[
"input"
],
...
...
@@ -3031,3 +3121,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return
out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func
(
s
*
GeminiMessagesCompatService
)
extractImageSize
(
body
[]
byte
)
string
{
var
req
struct
{
GenerationConfig
*
struct
{
ImageConfig
*
struct
{
ImageSize
string
`json:"imageSize"`
}
`json:"imageConfig"`
}
`json:"generationConfig"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
"2K"
}
if
req
.
GenerationConfig
!=
nil
&&
req
.
GenerationConfig
.
ImageConfig
!=
nil
{
size
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
req
.
GenerationConfig
.
ImageConfig
.
ImageSize
))
if
size
==
"1K"
||
size
==
"2K"
||
size
==
"4K"
{
return
size
}
}
return
"2K"
}
backend/internal/service/gemini_messages_compat_service_test.go
View file @
377bffe2
package
service
import
(
"encoding/json"
"strings"
"testing"
)
...
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
}
}
func
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse
(
t
*
testing
.
T
)
{
claudeReq
:=
map
[
string
]
any
{
"model"
:
"claude-haiku-4-5-20251001"
,
"max_tokens"
:
10
,
"messages"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"hi"
},
},
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"ok"
},
map
[
string
]
any
{
"type"
:
"tool_use"
,
"id"
:
"toolu_123"
,
"name"
:
"default_api:write_file"
,
"input"
:
map
[
string
]
any
{
"path"
:
"a.txt"
,
"content"
:
"x"
},
// no signature on purpose
},
},
},
},
"tools"
:
[]
any
{
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"description"
:
"write file"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"path"
:
map
[
string
]
any
{
"type"
:
"string"
}},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
claudeReq
)
out
,
err
:=
convertClaudeMessagesToGeminiGenerateContent
(
b
)
if
err
!=
nil
{
t
.
Fatalf
(
"convert failed: %v"
,
err
)
}
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
functionCall
\"
"
)
{
t
.
Fatalf
(
"expected functionCall in output, got: %s"
,
s
)
}
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
func
TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing
(
t
*
testing
.
T
)
{
geminiReq
:=
map
[
string
]
any
{
"contents"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"parts"
:
[]
any
{
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"args"
:
map
[
string
]
any
{
"path"
:
"a.txt"
},
},
},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
geminiReq
)
out
:=
ensureGeminiFunctionCallThoughtSignatures
(
b
)
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
backend/internal/service/gemini_multiplatform_test.go
View file @
377bffe2
...
...
@@ -221,6 +221,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
...
...
backend/internal/service/gemini_native_signature_cleaner.go
0 → 100644
View file @
377bffe2
package
service
import
(
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
// 解析 JSON
var
data
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return
body
}
// 递归清理 thoughtSignature
cleaned
:=
cleanThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
cleaned
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
}
return
result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func
cleanThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,移除 thoughtSignature
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
// 跳过 thoughtSignature 字段
if
key
==
"thoughtSignature"
{
continue
}
// 递归处理嵌套结构
result
[
key
]
=
cleanThoughtSignaturesRecursive
(
value
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
cleanThoughtSignaturesRecursive
(
item
)
}
return
result
default
:
// 基本类型(string, number, bool, null)直接返回
return
v
}
}
backend/internal/service/group_service.go
View file @
377bffe2
...
...
@@ -29,6 +29,10 @@ type GroupRepository interface {
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
// BindAccountsToGroup 将多个账号绑定到指定分组
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
}
// CreateGroupRequest 创建分组请求
...
...
backend/internal/service/identity_service.go
View file @
377bffe2
...
...
@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.
0.6
2 (external, cli)"
,
UserAgent
:
"claude-cli/2.
1.2
2 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.
52
.0"
,
StainlessPackageVersion
:
"0.
70
.0"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v2
2
.1
4
.0"
,
StainlessRuntimeVersion
:
"v2
4
.1
3
.0"
,
}
// Fingerprint represents account fingerprint data
...
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.
0.6
2 -> (2,
0
,
6
2)
// 例如:claude-cli/2.
1.
2 -> (2,
1
, 2)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment