Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
31fe0178
Commit
31fe0178
authored
Feb 03, 2026
by
yangjianbo
Browse files
Merge branch 'main' of
https://github.com/mt21625457/aicodex2api
parents
d9e345f2
ba5a0d47
Changes
235
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_quota_scope.go
View file @
31fe0178
...
@@ -89,3 +89,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
...
@@ -89,3 +89,30 @@ func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *tim
}
}
return
&
resetAt
return
&
resetAt
}
}
var
antigravityAllScopes
=
[]
AntigravityQuotaScope
{
AntigravityQuotaScopeClaude
,
AntigravityQuotaScopeGeminiText
,
AntigravityQuotaScopeGeminiImage
,
}
func
(
a
*
Account
)
GetAntigravityScopeRateLimits
()
map
[
string
]
int64
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAntigravity
{
return
nil
}
now
:=
time
.
Now
()
result
:=
make
(
map
[
string
]
int64
)
for
_
,
scope
:=
range
antigravityAllScopes
{
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
!=
nil
&&
now
.
Before
(
*
resetAt
)
{
remainingSec
:=
int64
(
time
.
Until
(
*
resetAt
)
.
Seconds
())
if
remainingSec
>
0
{
result
[
string
(
scope
)]
=
remainingSec
}
}
}
if
len
(
result
)
==
0
{
return
nil
}
return
result
}
backend/internal/service/antigravity_token_refresher.go
View file @
31fe0178
...
@@ -3,6 +3,8 @@ package service
...
@@ -3,6 +3,8 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"log"
"strings"
"time"
"time"
)
)
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
...
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for
k
,
v
:=
range
account
.
Credentials
{
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
newCredentials
[
k
]
=
v
}
}
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if
newProjectID
,
_
:=
newCredentials
[
"project_id"
]
.
(
string
);
newProjectID
==
""
{
if
oldProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
));
oldProjectID
!=
""
{
newCredentials
[
"project_id"
]
=
oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if
tokenInfo
.
ProjectIDMissing
{
if
tokenInfo
.
ProjectIDMissing
{
return
newCredentials
,
fmt
.
Errorf
(
"missing_project_id: 账户缺少project id,可能无法使用Antigravity"
)
if
tokenInfo
.
ProjectID
!=
""
{
// 有旧的 project_id,本次获取失败,保留旧值
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id"
,
account
.
ID
)
}
else
{
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log
.
Printf
(
"[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试"
,
account
.
ID
)
}
}
}
return
newCredentials
,
nil
return
newCredentials
,
nil
...
...
backend/internal/service/auth_service.go
View file @
31fe0178
...
@@ -19,17 +19,19 @@ import (
...
@@ -19,17 +19,19 @@ import (
)
)
var
(
var
(
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrInvitationCodeRequired
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_REQUIRED"
,
"invitation code is required"
)
ErrInvitationCodeInvalid
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_INVALID"
,
"invalid or used invitation code"
)
)
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
// AuthService 认证服务
type
AuthService
struct
{
type
AuthService
struct
{
userRepo
UserRepository
userRepo
UserRepository
redeemRepo
RedeemCodeRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
settingService
*
SettingService
settingService
*
SettingService
emailService
*
EmailService
emailService
*
EmailService
...
@@ -58,6 +61,7 @@ type AuthService struct {
...
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
// NewAuthService 创建认证服务实例
func
NewAuthService
(
func
NewAuthService
(
userRepo
UserRepository
,
userRepo
UserRepository
,
redeemRepo
RedeemCodeRepository
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
settingService
*
SettingService
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailService
*
EmailService
,
...
@@ -67,6 +71,7 @@ func NewAuthService(
...
@@ -67,6 +71,7 @@ func NewAuthService(
)
*
AuthService
{
)
*
AuthService
{
return
&
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
userRepo
:
userRepo
,
redeemRepo
:
redeemRepo
,
cfg
:
cfg
,
cfg
:
cfg
,
settingService
:
settingService
,
settingService
:
settingService
,
emailService
:
emailService
,
emailService
:
emailService
,
...
@@ -78,11 +83,11 @@ func NewAuthService(
...
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
)
}
}
// RegisterWithVerification 用户注册(支持邮件验证
和
优惠码),返回token和用户
// RegisterWithVerification 用户注册(支持邮件验证
、
优惠码
和邀请码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
return
""
,
nil
,
ErrRegDisabled
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrEmailReserved
return
""
,
nil
,
ErrEmailReserved
}
}
// 检查是否需要邀请码
var
invitationRedeemCode
*
RedeemCode
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsInvitationCodeEnabled
(
ctx
)
{
if
invitationCode
==
""
{
return
""
,
nil
,
ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
// 检查类型和状态
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
log
.
Printf
(
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
invitationRedeemCode
=
redeemCode
}
// 检查是否需要邮件验证
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
...
@@ -153,6 +178,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -153,6 +178,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
// 邀请码标记失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
@@ -580,3 +613,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
...
@@ -580,3 +613,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
// 生成新token
return
s
.
GenerateToken
(
user
)
return
s
.
GenerateToken
(
user
)
}
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func
(
s
*
AuthService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
false
}
// Must have email verification enabled and SMTP configured
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
return
s
.
settingService
.
IsPasswordResetEnabled
(
ctx
)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func
(
s
*
AuthService
)
preparePasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
(
string
,
string
,
bool
)
{
// Check if user exists (but don't reveal this to the caller)
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// Security: Log but don't reveal that user doesn't exist
log
.
Printf
(
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
return
""
,
""
,
false
}
log
.
Printf
(
"[Auth] Database error checking email for password reset: %v"
,
err
)
return
""
,
""
,
false
}
// Check if user is active
if
!
user
.
IsActive
()
{
log
.
Printf
(
"[Auth] Password reset requested for inactive user: %s"
,
email
)
return
""
,
""
,
false
}
// Get site name
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
// Build reset URL base
resetURL
:=
fmt
.
Sprintf
(
"%s/reset-password"
,
strings
.
TrimSuffix
(
frontendBaseURL
,
"/"
))
return
siteName
,
resetURL
,
true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email sent to: %s"
,
email
)
return
nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordResetAsync
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailQueueService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email enqueued for: %s"
,
email
)
return
nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func
(
s
*
AuthService
)
ResetPassword
(
ctx
context
.
Context
,
email
,
token
,
newPassword
string
)
error
{
// Check if password reset is enabled
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if
err
:=
s
.
emailService
.
ConsumePasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
ErrInvalidResetToken
// Token was valid but user was deleted
}
log
.
Printf
(
"[Auth] Database error getting user for password reset: %v"
,
err
)
return
ErrServiceUnavailable
}
// Check if user is active
if
!
user
.
IsActive
()
{
return
ErrUserNotActive
}
// Hash new password
hashedPassword
,
err
:=
s
.
HashPassword
(
newPassword
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// Update password and increment TokenVersion
user
.
PasswordHash
=
hashedPassword
user
.
TokenVersion
++
// Invalidate all existing tokens
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
return
ErrServiceUnavailable
}
log
.
Printf
(
"[Auth] Password reset successful for user: %s"
,
email
)
return
nil
}
backend/internal/service/auth_service_register_test.go
View file @
31fe0178
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
return
nil
}
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
{
return
false
}
func
(
s
*
emailCacheStub
)
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
JWT
:
config
.
JWTConfig
{
...
@@ -95,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
...
@@ -95,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return
NewAuthService
(
return
NewAuthService
(
repo
,
repo
,
nil
,
// redeemRepo
cfg
,
cfg
,
settingService
,
settingService
,
emailService
,
emailService
,
...
@@ -132,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
...
@@ -132,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
}
...
@@ -144,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
...
@@ -144,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
}
...
@@ -158,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
...
@@ -158,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
}
...
...
backend/internal/service/billing_service.go
View file @
31fe0178
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
}
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func
(
s
*
BillingService
)
CalculateCostWithLongContext
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
threshold
int
,
extraMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
// 未启用长上下文计费,直接走正常计费
if
threshold
<=
0
||
extraMultiplier
<=
1
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 计算总输入 token(缓存读取 + 新输入)
total
:=
tokens
.
CacheReadTokens
+
tokens
.
InputTokens
if
total
<=
threshold
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 拆分成范围内和范围外
var
inRangeCacheTokens
,
inRangeInputTokens
int
var
outRangeCacheTokens
,
outRangeInputTokens
int
if
tokens
.
CacheReadTokens
>=
threshold
{
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens
=
threshold
inRangeInputTokens
=
0
outRangeCacheTokens
=
tokens
.
CacheReadTokens
-
threshold
outRangeInputTokens
=
tokens
.
InputTokens
}
else
{
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens
=
tokens
.
CacheReadTokens
inRangeInputTokens
=
threshold
-
tokens
.
CacheReadTokens
outRangeCacheTokens
=
0
outRangeInputTokens
=
tokens
.
InputTokens
-
inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens
:=
UsageTokens
{
InputTokens
:
inRangeInputTokens
,
OutputTokens
:
tokens
.
OutputTokens
,
// 输出只算一次
CacheCreationTokens
:
tokens
.
CacheCreationTokens
,
CacheReadTokens
:
inRangeCacheTokens
,
}
inRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
inRangeTokens
,
rateMultiplier
)
if
err
!=
nil
{
return
nil
,
err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens
:=
UsageTokens
{
InputTokens
:
outRangeInputTokens
,
CacheReadTokens
:
outRangeCacheTokens
,
}
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
if
err
!=
nil
{
return
inRangeCost
,
nil
// 出错时返回范围内成本
}
// 合并成本
return
&
CostBreakdown
{
InputCost
:
inRangeCost
.
InputCost
+
outRangeCost
.
InputCost
,
OutputCost
:
inRangeCost
.
OutputCost
,
CacheCreationCost
:
inRangeCost
.
CacheCreationCost
,
CacheReadCost
:
inRangeCost
.
CacheReadCost
+
outRangeCost
.
CacheReadCost
,
TotalCost
:
inRangeCost
.
TotalCost
+
outRangeCost
.
TotalCost
,
ActualCost
:
inRangeCost
.
ActualCost
+
outRangeCost
.
ActualCost
,
},
nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
models
:=
make
([]
string
,
0
)
models
:=
make
([]
string
,
0
)
...
...
backend/internal/service/domain_constants.go
View file @
31fe0178
package
service
package
service
import
"github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants
// Status constants
const
(
const
(
StatusActive
=
"a
ctive
"
StatusActive
=
domain
.
StatusA
ctive
StatusDisabled
=
"
disabled
"
StatusDisabled
=
d
omain
.
StatusD
isabled
StatusError
=
"e
rror
"
StatusError
=
domain
.
StatusE
rror
StatusUnused
=
"u
nused
"
StatusUnused
=
domain
.
StatusU
nused
StatusUsed
=
"u
sed
"
StatusUsed
=
domain
.
StatusU
sed
StatusExpired
=
"e
xpired
"
StatusExpired
=
domain
.
StatusE
xpired
)
)
// Role constants
// Role constants
const
(
const
(
RoleAdmin
=
"a
dmin
"
RoleAdmin
=
domain
.
RoleA
dmin
RoleUser
=
"u
ser
"
RoleUser
=
domain
.
RoleU
ser
)
)
// Platform constants
// Platform constants
const
(
const
(
PlatformAnthropic
=
"a
nthropic
"
PlatformAnthropic
=
domain
.
PlatformA
nthropic
PlatformOpenAI
=
"openai"
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformGemini
=
"g
emini
"
PlatformGemini
=
domain
.
PlatformG
emini
PlatformAntigravity
=
"a
ntigravity
"
PlatformAntigravity
=
domain
.
PlatformA
ntigravity
)
)
// Account type constants
// Account type constants
const
(
const
(
AccountTypeOAuth
=
"oa
uth
"
// OAuth类型账号(full scope: profile + inference)
AccountTypeOAuth
=
domain
.
AccountTypeOA
uth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"s
etup
-t
oken
"
// Setup Token类型账号(inference only scope)
AccountTypeSetupToken
=
domain
.
AccountTypeS
etup
T
oken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
"apikey"
// API Key类型账号
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
)
)
// Redeem type constants
// Redeem type constants
const
(
const
(
RedeemTypeBalance
=
"balance"
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeConcurrency
=
"concurrency"
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeSubscription
=
"subscription"
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeInvitation
=
domain
.
RedeemTypeInvitation
)
)
// PromoCode status constants
// PromoCode status constants
const
(
const
(
PromoCodeStatusActive
=
"a
ctive
"
PromoCodeStatusActive
=
domain
.
PromoCodeStatusA
ctive
PromoCodeStatusDisabled
=
"
disabled
"
PromoCodeStatusDisabled
=
d
omain
.
PromoCodeStatusD
isabled
)
)
// Admin adjustment type constants
// Admin adjustment type constants
const
(
const
(
AdjustmentTypeAdminBalance
=
"a
dmin
_b
alance
"
// 管理员调整余额
AdjustmentTypeAdminBalance
=
domain
.
AdjustmentTypeA
dmin
B
alance
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"a
dmin
_c
oncurrency
"
// 管理员调整并发数
AdjustmentTypeAdminConcurrency
=
domain
.
AdjustmentTypeA
dmin
C
oncurrency
// 管理员调整并发数
)
)
// Group subscription type constants
// Group subscription type constants
const
(
const
(
SubscriptionTypeStandard
=
"s
tandard
"
// 标准计费模式(按余额扣费)
SubscriptionTypeStandard
=
domain
.
SubscriptionTypeS
tandard
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"s
ubscription
"
// 订阅模式(按限额控制)
SubscriptionTypeSubscription
=
domain
.
SubscriptionTypeS
ubscription
// 订阅模式(按限额控制)
)
)
// Subscription status constants
// Subscription status constants
const
(
const
(
SubscriptionStatusActive
=
"a
ctive
"
SubscriptionStatusActive
=
domain
.
SubscriptionStatusA
ctive
SubscriptionStatusExpired
=
"e
xpired
"
SubscriptionStatusExpired
=
domain
.
SubscriptionStatusE
xpired
SubscriptionStatusSuspended
=
"s
uspended
"
SubscriptionStatusSuspended
=
domain
.
SubscriptionStatusS
uspended
)
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
...
@@ -69,9 +72,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
...
@@ -69,9 +72,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
// Setting keys
const
(
const
(
// 注册设置
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled
=
"invitation_code_enabled"
// 是否启用邀请码注册
// 邮件服务设置
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
@@ -87,6 +92,9 @@ const (
...
@@ -87,6 +92,9 @@ const (
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled
=
"totp_enabled"
// 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
...
@@ -94,14 +102,16 @@ const (
...
@@ -94,14 +102,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// “购买订阅”页面 URL(作为 iframe src)
// 默认配置
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
backend/internal/service/email_queue_service.go
View file @
31fe0178
...
@@ -8,11 +8,18 @@ import (
...
@@ -8,11 +8,18 @@ import (
"time"
"time"
)
)
// Task type constants
const
(
TaskTypeVerifyCode
=
"verify_code"
TaskTypePasswordReset
=
"password_reset"
)
// EmailTask 邮件发送任务
// EmailTask 邮件发送任务
type
EmailTask
struct
{
type
EmailTask
struct
{
Email
string
Email
string
SiteName
string
SiteName
string
TaskType
string
// "verify_code"
TaskType
string
// "verify_code" or "password_reset"
ResetURL
string
// Only used for password_reset task type
}
}
// EmailQueueService 异步邮件队列服务
// EmailQueueService 异步邮件队列服务
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer
cancel
()
defer
cancel
()
switch
task
.
TaskType
{
switch
task
.
TaskType
{
case
"v
erify
_c
ode
"
:
case
TaskTypeV
erify
C
ode
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
}
case
TaskTypePasswordReset
:
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
}
default
:
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
}
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task
:=
EmailTask
{
task
:=
EmailTask
{
Email
:
email
,
Email
:
email
,
SiteName
:
siteName
,
SiteName
:
siteName
,
TaskType
:
"v
erify
_c
ode
"
,
TaskType
:
TaskTypeV
erify
C
ode
,
}
}
select
{
select
{
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func
(
s
*
EmailQueueService
)
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
string
)
error
{
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
TaskTypePasswordReset
,
ResetURL
:
resetURL
,
}
select
{
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
return
nil
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
}
}
// Stop 停止队列服务
// Stop 停止队列服务
func
(
s
*
EmailQueueService
)
Stop
()
{
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
close
(
s
.
stopChan
)
...
...
backend/internal/service/email_service.go
View file @
31fe0178
...
@@ -3,11 +3,14 @@ package service
...
@@ -3,11 +3,14 @@ package service
import
(
import
(
"context"
"context"
"crypto/rand"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/tls"
"encoding/hex"
"fmt"
"fmt"
"log"
"log"
"math/big"
"math/big"
"net/smtp"
"net/smtp"
"net/url"
"strconv"
"strconv"
"time"
"time"
...
@@ -19,6 +22,9 @@ var (
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
)
// EmailCache defines cache operations for email service
// EmailCache defines cache operations for email service
...
@@ -26,6 +32,16 @@ type EmailCache interface {
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
}
// VerificationCodeData represents verification code data
// VerificationCodeData represents verification code data
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
}
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
)
// SMTPConfig SMTP配置
// SMTPConfig SMTP配置
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
return
ErrVerifyCodeMaxAttempts
}
}
// 验证码不匹配
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
data
.
Code
!=
code
{
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
return
client
.
Quit
()
}
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
31fe0178
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
31fe0178
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGateway
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
return
&
v
}
}
...
...
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
31fe0178
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
31fe0178
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
tests
:=
[]
struct
{
name
string
name
string
body
string
body
string
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
},
{
{
name
:
"string system equals Claude Code prompt"
,
name
:
"string system equals Claude Code prompt"
,
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
},
{
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
},
{
{
name
:
"empty array"
,
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
31fe0178
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
31fe0178
...
@@ -20,12 +20,14 @@ import (
...
@@ -20,12 +20,14 @@ import (
"strings"
"strings"
"sync/atomic"
"sync/atomic"
"time"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
...
@@ -37,8 +39,15 @@ const (
...
@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
defaultMaxLineSize
=
40
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
if
sessionHash
==
""
{
return
""
return
""
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
...
@@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string {
return
sessionHash
[
:
8
]
return
sessionHash
[
:
8
]
}
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
var
(
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
toolPrefixRe
=
regexp
.
MustCompile
(
`(?i)^(?:oc_|mcp_)`
)
toolNameBoundaryRe
=
regexp
.
MustCompile
(
`[^a-zA-Z0-9]+`
)
toolNameCamelRe
=
regexp
.
MustCompile
(
`([a-z0-9])([A-Z])`
)
toolNameFieldRe
=
regexp
.
MustCompile
(
`"name"\s*:\s*"([^"]+)"`
)
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
"read"
:
"Read"
,
"edit"
:
"Edit"
,
"write"
:
"Write"
,
"task"
:
"Task"
,
"glob"
:
"Glob"
,
"grep"
:
"Grep"
,
"webfetch"
:
"WebFetch"
,
"websearch"
:
"WebSearch"
,
"todowrite"
:
"TodoWrite"
,
"question"
:
"AskUserQuestion"
,
}
openCodeToolOverrides
=
map
[
string
]
string
{
"Bash"
:
"bash"
,
"Read"
:
"read"
,
"Edit"
:
"edit"
,
"Write"
:
"write"
,
"Task"
:
"task"
,
"Glob"
:
"glob"
,
"Grep"
:
"grep"
,
"WebFetch"
:
"webfetch"
,
"WebSearch"
:
"websearch"
,
"TodoWrite"
:
"todowrite"
,
"AskUserQuestion"
:
"question"
,
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...
@@ -305,6 +485,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
...
@@ -305,6 +485,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
}
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func
(
s
*
GatewayService
)
GetCachedSessionAccountID
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
sessionHash
==
""
||
s
.
cache
==
nil
{
return
0
,
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
!=
nil
{
return
0
,
err
}
return
accountID
,
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
if
parsed
==
nil
{
return
""
return
""
...
@@ -405,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
...
@@ -405,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return
newBody
return
newBody
}
}
type
claudeOAuthNormalizeOptions
struct
{
injectMetadata
bool
metadataUserID
string
stripSystemCacheControl
bool
}
func
stripToolPrefix
(
value
string
)
string
{
if
value
==
""
{
return
value
}
return
toolPrefixRe
.
ReplaceAllString
(
value
,
""
)
}
func
toPascalCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
normalized
:=
toolNameBoundaryRe
.
ReplaceAllString
(
value
,
" "
)
tokens
:=
make
([]
string
,
0
)
for
_
,
token
:=
range
strings
.
Fields
(
normalized
)
{
expanded
:=
toolNameCamelRe
.
ReplaceAllString
(
token
,
"$1 $2"
)
parts
:=
strings
.
Fields
(
expanded
)
if
len
(
parts
)
>
0
{
tokens
=
append
(
tokens
,
parts
...
)
}
}
if
len
(
tokens
)
==
0
{
return
value
}
var
builder
strings
.
Builder
for
_
,
token
:=
range
tokens
{
lower
:=
strings
.
ToLower
(
token
)
if
lower
==
""
{
continue
}
runes
:=
[]
rune
(
lower
)
runes
[
0
]
=
unicode
.
ToUpper
(
runes
[
0
])
_
,
_
=
builder
.
WriteString
(
string
(
runes
))
}
return
builder
.
String
()
}
func
toSnakeCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
output
:=
toolNameCamelRe
.
ReplaceAllString
(
value
,
"$1_$2"
)
output
=
toolNameBoundaryRe
.
ReplaceAllString
(
output
,
"_"
)
output
=
strings
.
Trim
(
output
,
"_"
)
return
strings
.
ToLower
(
output
)
}
func
normalizeToolNameForClaude
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
mapped
,
ok
:=
claudeToolNameOverrides
[
strings
.
ToLower
(
stripped
)]
if
!
ok
{
mapped
=
toPascalCase
(
stripped
)
}
if
mapped
!=
""
&&
cache
!=
nil
&&
mapped
!=
stripped
{
cache
[
mapped
]
=
stripped
}
if
mapped
==
""
{
return
stripped
}
return
mapped
}
func
normalizeToolNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
stripped
];
ok
{
return
mapped
}
}
if
mapped
,
ok
:=
openCodeToolOverrides
[
stripped
];
ok
{
return
mapped
}
return
toSnakeCase
(
stripped
)
}
func
normalizeParamNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
name
];
ok
{
return
mapped
}
}
return
name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
func
sanitizeToolDescription
(
description
string
)
string
{
if
description
==
""
{
return
description
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
schema
,
ok
:=
inputSchema
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
properties
,
ok
:=
schema
[
"properties"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
newProperties
:=
make
(
map
[
string
]
any
,
len
(
properties
))
for
key
,
value
:=
range
properties
{
snakeKey
:=
toSnakeCase
(
key
)
newProperties
[
snakeKey
]
=
value
if
snakeKey
!=
key
&&
cache
!=
nil
{
cache
[
snakeKey
]
=
key
}
}
schema
[
"properties"
]
=
newProperties
if
required
,
ok
:=
schema
[
"required"
]
.
([]
any
);
ok
{
newRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
item
:=
range
required
{
name
,
ok
:=
item
.
(
string
)
if
!
ok
{
newRequired
=
append
(
newRequired
,
item
)
continue
}
snakeName
:=
toSnakeCase
(
name
)
newRequired
=
append
(
newRequired
,
snakeName
)
if
snakeName
!=
name
&&
cache
!=
nil
{
cache
[
snakeName
]
=
name
}
}
schema
[
"required"
]
=
newRequired
}
}
func
stripCacheControlFromSystemBlocks
(
system
any
)
bool
{
blocks
,
ok
:=
system
.
([]
any
)
if
!
ok
{
return
false
}
changed
:=
false
for
_
,
item
:=
range
blocks
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
_
,
exists
:=
block
[
"cache_control"
];
!
exists
{
continue
}
delete
(
block
,
"cache_control"
)
changed
=
true
}
return
changed
}
func
normalizeClaudeOAuthRequestBody
(
body
[]
byte
,
modelID
string
,
opts
claudeOAuthNormalizeOptions
)
([]
byte
,
string
,
map
[
string
]
string
)
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
case
[]
any
:
for
_
,
item
:=
range
v
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
block
[
"type"
]
.
(
string
);
blockType
!=
"text"
{
continue
}
text
,
ok
:=
block
[
"text"
]
.
(
string
)
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
}
}
}
if
rawModel
,
ok
:=
req
[
"model"
]
.
(
string
);
ok
{
normalized
:=
claude
.
NormalizeModelID
(
rawModel
)
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
}
}
if
rawTools
,
exists
:=
req
[
"tools"
];
exists
{
switch
tools
:=
rawTools
.
(
type
)
{
case
[]
any
:
for
idx
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
tools
[
idx
]
=
toolMap
}
req
[
"tools"
]
=
tools
case
map
[
string
]
any
:
normalizedTools
:=
make
(
map
[
string
]
any
,
len
(
tools
))
for
name
,
value
:=
range
tools
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
==
""
{
normalized
=
name
}
if
toolMap
,
ok
:=
value
.
(
map
[
string
]
any
);
ok
{
toolMap
[
"name"
]
=
normalized
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
normalizedTools
[
normalized
]
=
toolMap
continue
}
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
continue
}
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
);
blockType
!=
"tool_use"
{
continue
}
if
name
,
ok
:=
blockMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
}
}
}
}
}
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
}
}
if
opts
.
injectMetadata
&&
opts
.
metadataUserID
!=
""
{
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
metadata
=
map
[
string
]
any
{}
req
[
"metadata"
]
=
metadata
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
return
newBody
,
modelID
,
toolNameMap
}
func
(
s
*
GatewayService
)
buildOAuthMetadataUserID
(
parsed
*
ParsedRequest
,
account
*
Account
,
fp
*
Fingerprint
)
string
{
if
parsed
==
nil
||
account
==
nil
{
return
""
}
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
sessionID
:=
uuid
.
NewString
()
if
sessionHash
!=
""
{
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
if
seed
==
""
{
return
uuid
.
NewString
()
}
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// SelectAccount 选择账号(粘性会话+优先级)
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
...
@@ -1880,6 +2461,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
...
@@ -1880,6 +2461,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
return
IsAntigravityModelSupported
(
requestedModel
)
}
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
return
true
}
// 其他平台使用账户的模型支持检查
// 其他平台使用账户的模型支持检查
return
account
.
IsModelSupported
(
requestedModel
)
return
account
.
IsModelSupported
(
requestedModel
)
}
}
...
@@ -2004,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
...
@@ -2004,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
}
}
func
isClaudeCodeRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
parsed
*
ParsedRequest
)
bool
{
if
IsClaudeCodeClient
(
ctx
)
{
return
true
}
if
parsed
==
nil
||
c
==
nil
{
return
false
}
return
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
...
@@ -2040,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2040,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
var
newSystem
[]
any
...
@@ -2047,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2047,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
}
case
[]
any
:
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
continue
}
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
newSystem
=
append
(
newSystem
,
item
)
}
}
...
@@ -2263,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2263,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
reqStream
:=
parsed
.
Stream
originalModel
:=
reqModel
var
toolNameMap
map
[
string
]
string
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
if
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
&&
fp
!=
nil
{
if
metadataUserID
:=
s
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
);
metadataUserID
!=
""
{
normalizeOpts
.
injectMetadata
=
true
normalizeOpts
.
metadataUserID
=
metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
body
,
reqModel
,
toolNameMap
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
account
.
IsOAuth
()
&&
!
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
&&
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射(仅对apikey类型账号)
originalModel
:=
reqModel
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
if
mappedModel
!=
reqModel
{
...
@@ -2309,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2309,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2390,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2390,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -2422,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2422,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
if
retryErr2
==
nil
{
...
@@ -2647,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2647,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
var
clientDisconnect
bool
if
reqStream
{
if
reqStream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
)
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
.
Error
()
==
"have error in stream"
{
if
err
.
Error
()
==
"have error in stream"
{
return
nil
,
&
UpstreamFailoverError
{
return
nil
,
&
UpstreamFailoverError
{
...
@@ -2660,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2660,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs
=
streamResult
.
firstTokenMs
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
)
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2677,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2677,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
},
nil
},
nil
}
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
// 确定目标URL
targetURL
:=
claudeAPIURL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -2691,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2691,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
// 失败时降级为透传原始headers
...
@@ -2726,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2726,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 白名单透传headers
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -2747,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2747,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
reqStream
)
}
// 处理anthropic-beta header(OAuth账号需要
特殊处理
)
// 处理
anthropic-beta header(OAuth
账号需要
包含 oauth beta
)
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
))
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -2760,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2760,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
@@ -2829,22 +3495,109 @@ func defaultAPIKeyBetaHeader(body []byte) string {
...
@@ -2829,22 +3495,109 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return
claude
.
APIKeyBetaHeader
return
claude
.
APIKeyBetaHeader
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
func
applyClaudeOAuthHeaderDefaults
(
req
*
http
.
Request
,
isStream
bool
)
{
if
maxBytes
<=
0
{
if
req
==
nil
{
maxBytes
=
2048
return
}
}
if
len
(
b
)
>
maxBytes
{
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
b
=
b
[
:
maxBytes
]
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
if
req
.
Header
.
Get
(
key
)
==
""
{
req
.
Header
.
Set
(
key
,
value
)
}
}
if
isStream
&&
req
.
Header
.
Get
(
"x-stainless-helper-method"
)
==
""
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
// 这类错误可以通过过滤thinking blocks并重试来解决
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
add
:=
func
(
v
string
)
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
}
if
len
(
b
)
>
maxBytes
{
b
=
b
[
:
maxBytes
]
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func
(
s
*
GatewayService
)
isThinkingBlockSignatureError
(
respBody
[]
byte
)
bool
{
func
(
s
*
GatewayService
)
isThinkingBlockSignatureError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
if
msg
==
""
{
if
msg
==
""
{
...
@@ -2932,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
...
@@ -2932,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
@@ -3061,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
...
@@ -3061,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
@@ -3113,7 +3893,7 @@ type streamingResult struct {
...
@@ -3113,7 +3893,7 @@ type streamingResult struct {
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3208,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3208,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines
:=
make
([]
string
,
0
,
4
)
var
toolInputBuffers
map
[
int
]
string
if
mimicClaudeCode
{
toolInputBuffers
=
make
(
map
[
int
]
string
)
}
transformToolInputJSON
:=
func
(
raw
string
)
string
{
if
!
mimicClaudeCode
{
return
raw
}
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
var
parsed
any
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
parsed
);
err
!=
nil
{
return
replaceToolNamesInText
(
raw
,
toolNameMap
)
}
rewritten
,
changed
:=
rewriteParamKeysInValue
(
parsed
,
toolNameMap
)
if
changed
{
if
bytes
,
err
:=
json
.
Marshal
(
rewritten
);
err
==
nil
{
return
string
(
bytes
)
}
}
return
raw
}
processSSEEvent
:=
func
(
lines
[]
string
)
([]
string
,
string
,
error
)
{
if
len
(
lines
)
==
0
{
return
nil
,
""
,
nil
}
eventName
:=
""
dataLine
:=
""
for
_
,
line
:=
range
lines
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
{
eventName
=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
))
continue
}
if
dataLine
==
""
&&
sseDataRe
.
MatchString
(
trimmed
)
{
dataLine
=
sseDataRe
.
ReplaceAllString
(
trimmed
,
""
)
}
}
if
eventName
==
"error"
{
return
nil
,
dataLine
,
errors
.
New
(
"have error in stream"
)
}
if
dataLine
==
""
{
return
[]
string
{
strings
.
Join
(
lines
,
"
\n
"
)
+
"
\n\n
"
},
""
,
nil
}
if
dataLine
==
"[DONE]"
{
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
dataLine
+
"
\n\n
"
return
[]
string
{
block
},
dataLine
,
nil
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
dataLine
),
&
event
);
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
eventType
,
_
:=
event
[
"type"
]
.
(
string
)
if
eventName
==
""
{
eventName
=
eventType
}
if
needModelReplace
{
if
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
model
,
ok
:=
msg
[
"model"
]
.
(
string
);
ok
&&
model
==
mappedModel
{
msg
[
"model"
]
=
originalModel
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_delta"
{
if
delta
,
ok
:=
event
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
deltaType
,
_
:=
delta
[
"type"
]
.
(
string
);
deltaType
==
"input_json_delta"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
partial
,
ok
:=
delta
[
"partial_json"
]
.
(
string
);
ok
{
toolInputBuffers
[
index
]
+=
partial
}
}
return
nil
,
dataLine
,
nil
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_stop"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
buffered
:=
toolInputBuffers
[
index
];
buffered
!=
""
{
delete
(
toolInputBuffers
,
index
)
transformed
:=
transformToolInputJSON
(
buffered
)
synthetic
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
any
{
"type"
:
"input_json_delta"
,
"partial_json"
:
transformed
,
},
}
synthBytes
,
synthErr
:=
json
.
Marshal
(
synthetic
)
if
synthErr
==
nil
{
synthBlock
:=
"event: content_block_delta
\n
"
+
"data: "
+
string
(
synthBytes
)
+
"
\n\n
"
rewriteToolNamesInValue
(
event
,
toolNameMap
)
stopBytes
,
stopErr
:=
json
.
Marshal
(
event
)
if
stopErr
==
nil
{
stopBlock
:=
""
if
eventName
!=
""
{
stopBlock
=
"event: "
+
eventName
+
"
\n
"
}
stopBlock
+=
"data: "
+
string
(
stopBytes
)
+
"
\n\n
"
return
[]
string
{
synthBlock
,
stopBlock
},
string
(
stopBytes
),
nil
}
}
}
}
}
if
mimicClaudeCode
{
rewriteToolNamesInValue
(
event
,
toolNameMap
)
}
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
string
(
newData
)
+
"
\n\n
"
return
[]
string
{
block
},
string
(
newData
),
nil
}
for
{
for
{
select
{
select
{
case
ev
,
ok
:=
<-
events
:
case
ev
,
ok
:=
<-
events
:
...
@@ -3236,43 +4181,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3236,43 +4181,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
}
line
:=
ev
.
line
line
:=
ev
.
line
if
line
==
"event: error"
{
trimmed
:=
strings
.
TrimSpace
(
line
)
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
errors
.
New
(
"have error in stream"
)
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
trimmed
==
""
{
var
data
string
if
len
(
pendingEventLines
)
==
0
{
if
sseDataRe
.
MatchString
(
line
)
{
continue
data
=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
}
}
// 写入客户端(统一处理 data 行和非 data 行)
outputBlocks
,
data
,
err
:=
processSSEEvent
(
pendingEventLines
)
if
!
clientDisconnected
{
pendingEventLines
=
pendingEventLines
[
:
0
]
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
err
!=
nil
{
clientDisconnected
=
true
if
clientDisconnected
{
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
else
{
}
flusher
.
Flush
()
return
nil
,
err
}
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
for
_
,
block
:=
range
outputBlocks
{
if
data
!=
""
{
if
!
clientDisconnected
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
if
_
,
werr
:=
fmt
.
Fprint
(
w
,
block
);
werr
!=
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
clientDisconnected
=
true
firstTokenMs
=
&
ms
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
break
}
flusher
.
Flush
()
}
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
}
s
.
parseSSEUsage
(
data
,
usage
)
continue
}
}
pendingEventLines
=
append
(
pendingEventLines
,
line
)
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
@@ -3295,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3295,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
rewriteParamKeysInValue
(
value
any
,
cache
map
[
string
]
string
)
(
any
,
bool
)
{
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
switch
v
:=
value
.
(
type
)
{
if
!
sseDataRe
.
MatchString
(
line
)
{
case
map
[
string
]
any
:
return
line
changed
:=
false
}
rewritten
:=
make
(
map
[
string
]
any
,
len
(
v
))
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
for
key
,
item
:=
range
v
{
if
data
==
""
||
data
==
"[DONE]"
{
newKey
:=
normalizeParamNameForOpenCode
(
key
,
cache
)
return
line
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
}
if
childChanged
{
changed
=
true
var
event
map
[
string
]
any
}
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
if
newKey
!=
key
{
return
line
changed
=
true
}
}
rewritten
[
newKey
]
=
newItem
// 只替换 message_start 事件中的 message.model
}
if
event
[
"type"
]
!=
"message_start"
{
if
!
changed
{
return
line
return
value
,
false
}
return
rewritten
,
true
case
[]
any
:
changed
:=
false
rewritten
:=
make
([]
any
,
len
(
v
))
for
idx
,
item
:=
range
v
{
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
rewritten
[
idx
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
default
:
return
value
,
false
}
}
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
)
func
rewriteToolNamesInValue
(
value
any
,
toolNameMap
map
[
string
]
string
)
bool
{
if
!
ok
{
switch
v
:=
value
.
(
type
)
{
return
line
case
map
[
string
]
any
:
changed
:=
false
if
blockType
,
_
:=
v
[
"type"
]
.
(
string
);
blockType
==
"tool_use"
{
if
name
,
ok
:=
v
[
"name"
]
.
(
string
);
ok
{
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
!=
name
{
v
[
"name"
]
=
mapped
changed
=
true
}
}
if
input
,
ok
:=
v
[
"input"
]
.
(
map
[
string
]
any
);
ok
{
rewrittenInput
,
inputChanged
:=
rewriteParamKeysInValue
(
input
,
toolNameMap
)
if
inputChanged
{
if
m
,
ok
:=
rewrittenInput
.
(
map
[
string
]
any
);
ok
{
v
[
"input"
]
=
m
changed
=
true
}
}
}
}
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
case
[]
any
:
changed
:=
false
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
default
:
return
false
}
}
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
func
replaceToolNamesInText
(
text
string
,
toolNameMap
map
[
string
]
string
)
string
{
if
!
ok
||
model
!=
fromModel
{
if
text
==
""
{
return
line
return
text
}
}
output
:=
toolNameFieldRe
.
ReplaceAllStringFunc
(
text
,
func
(
match
string
)
string
{
submatches
:=
toolNameFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
name
:=
submatches
[
1
]
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
==
name
{
return
match
}
return
strings
.
Replace
(
match
,
name
,
mapped
,
1
)
})
output
=
modelFieldRe
.
ReplaceAllStringFunc
(
output
,
func
(
match
string
)
string
{
submatches
:=
modelFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
model
:=
submatches
[
1
]
mapped
:=
claude
.
DenormalizeModelID
(
model
)
if
mapped
==
model
{
return
match
}
return
strings
.
Replace
(
match
,
model
,
mapped
,
1
)
})
msg
[
"model"
]
=
toModel
for
mapped
,
original
:=
range
toolNameMap
{
newData
,
err
:=
json
.
Marshal
(
event
)
if
mapped
==
""
||
original
==
""
||
mapped
==
original
{
if
err
!=
nil
{
continue
return
line
}
output
=
strings
.
ReplaceAll
(
output
,
"
\"
"
+
mapped
+
"
\"
:"
,
"
\"
"
+
original
+
"
\"
:"
)
output
=
strings
.
ReplaceAll
(
output
,
"
\\\"
"
+
mapped
+
"
\\\"
:"
,
"
\\\"
"
+
original
+
"
\\\"
:"
)
}
}
return
"data: "
+
string
(
newData
)
return
output
}
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
...
@@ -3359,23 +4386,25 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
...
@@ -3359,23 +4386,25 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
`json:"usage"`
}
`json:"usage"`
}
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgDelta
)
==
nil
&&
msgDelta
.
Type
==
"message_delta"
{
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgDelta
)
==
nil
&&
msgDelta
.
Type
==
"message_delta"
{
// output_tokens 总是从 message_delta 获取
// message_delta 仅覆盖存在且非0的字段
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
// 避免覆盖 message_start 中已有的值(如 input_tokens)
// Claude API 的 message_delta 通常只包含 output_tokens
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
if
msgDelta
.
Usage
.
InputTokens
>
0
{
if
usage
.
InputTokens
==
0
{
usage
.
InputTokens
=
msgDelta
.
Usage
.
InputTokens
usage
.
InputTokens
=
msgDelta
.
Usage
.
InputTokens
}
}
if
usage
.
CacheCreationInputTokens
==
0
{
if
msgDelta
.
Usage
.
OutputTokens
>
0
{
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
}
if
msgDelta
.
Usage
.
CacheCreationInputTokens
>
0
{
usage
.
CacheCreationInputTokens
=
msgDelta
.
Usage
.
CacheCreationInputTokens
usage
.
CacheCreationInputTokens
=
msgDelta
.
Usage
.
CacheCreationInputTokens
}
}
if
u
sage
.
CacheReadInputTokens
==
0
{
if
msgDelta
.
U
sage
.
CacheReadInputTokens
>
0
{
usage
.
CacheReadInputTokens
=
msgDelta
.
Usage
.
CacheReadInputTokens
usage
.
CacheReadInputTokens
=
msgDelta
.
Usage
.
CacheReadInputTokens
}
}
}
}
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3396,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
...
@@ -3396,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if
originalModel
!=
mappedModel
{
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
}
if
mimicClaudeCode
{
body
=
s
.
replaceToolNamesInResponseBody
(
body
,
toolNameMap
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
@@ -3433,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
...
@@ -3433,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
newBody
return
newBody
}
}
func
(
s
*
GatewayService
)
replaceToolNamesInResponseBody
(
body
[]
byte
,
toolNameMap
map
[
string
]
string
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
replaced
:=
replaceToolNamesInText
(
string
(
body
),
toolNameMap
)
if
replaced
==
string
(
body
)
{
return
body
}
return
[]
byte
(
replaced
)
}
if
!
rewriteToolNamesInValue
(
resp
,
toolNameMap
)
{
return
body
}
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
...
@@ -3587,6 +4641,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
...
@@ -3587,6 +4641,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return
nil
return
nil
}
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type
RecordUsageLongContextInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func
(
s
*
GatewayService
)
RecordUsageWithLongContext
(
ctx
context
.
Context
,
input
*
RecordUsageLongContextInput
)
error
{
result
:=
input
.
Result
apiKey
:=
input
.
APIKey
user
:=
input
.
User
account
:=
input
.
Account
subscription
:=
input
.
Subscription
// 获取费率倍数
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
}
var
cost
*
CostBreakdown
// 根据请求类型选择计费方式
if
result
.
ImageCount
>
0
{
// 图片生成计费
var
groupConfig
*
ImagePriceConfig
if
apiKey
.
Group
!=
nil
{
groupConfig
=
&
ImagePriceConfig
{
Price1K
:
apiKey
.
Group
.
ImagePrice1K
,
Price2K
:
apiKey
.
Group
.
ImagePrice2K
,
Price4K
:
apiKey
.
Group
.
ImagePrice4K
,
}
}
cost
=
s
.
billingService
.
CalculateImageCost
(
result
.
Model
,
result
.
ImageSize
,
result
.
ImageCount
,
groupConfig
,
multiplier
)
}
else
{
// Token 计费(使用长上下文计费方法)
tokens
:=
UsageTokens
{
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
}
var
err
error
cost
,
err
=
s
.
billingService
.
CalculateCostWithLongContext
(
result
.
Model
,
tokens
,
multiplier
,
input
.
LongContextThreshold
,
input
.
LongContextMultiplier
)
if
err
!=
nil
{
log
.
Printf
(
"Calculate cost failed: %v"
,
err
)
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling
:=
subscription
!=
nil
&&
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
billingType
:=
BillingTypeBalance
if
isSubscriptionBilling
{
billingType
=
BillingTypeSubscription
}
// 创建使用日志
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
var
imageSize
*
string
if
result
.
ImageSize
!=
""
{
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
imageSize
,
CreatedAt
:
time
.
Now
(),
}
// 添加 UserAgent
if
input
.
UserAgent
!=
""
{
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
}
if
subscription
!=
nil
{
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
err
!=
nil
{
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
// 根据计费类型执行扣费
if
isSubscriptionBilling
{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if
shouldBill
&&
cost
.
TotalCost
>
0
{
if
err
:=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
log
.
Printf
(
"Increment subscription usage failed: %v"
,
err
)
}
// 异步更新订阅缓存
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
}
}
else
{
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if
shouldBill
&&
cost
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Deduct balance failed: %v"
,
err
)
}
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
}
}
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
// 特点:不记录使用量、仅支持非流式响应
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
...
@@ -3598,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3598,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
body
,
reqModel
,
_
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if
account
.
Platform
==
PlatformAntigravity
{
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
...
@@ -3624,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3624,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// 构建上游请求
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
return
err
return
err
...
@@ -3657,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3657,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -3722,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3722,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// buildCountTokensRequest 构建 count_tokens 上游请求
// buildCountTokensRequest 构建 count_tokens 上游请求
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标 URL
// 确定目标 URL
targetURL
:=
claudeAPICountTokensURL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -3736,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3736,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
...
@@ -3763,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3763,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 白名单透传 headers
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -3774,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3774,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
}
...
@@ -3787,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3787,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
false
)
}
// OAuth 账号:处理 anthropic-beta header
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
}
else
{
beta
:=
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
)
if
!
strings
.
Contains
(
beta
,
claude
.
BetaTokenCounting
)
{
beta
=
beta
+
","
+
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -3800,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3800,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
31fe0178
...
@@ -36,6 +36,11 @@ const (
...
@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay
=
16
*
time
.
Second
geminiRetryMaxDelay
=
16
*
time
.
Second
)
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const
geminiDummyThoughtSignature
=
"skip_thought_signature_validator"
type
GeminiMessagesCompatService
struct
{
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
groupRepo
GroupRepository
groupRepo
GroupRepository
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
}
geminiReq
=
ensureGeminiFunctionCallThoughtSignatures
(
geminiReq
)
originalClaudeBody
:=
body
originalClaudeBody
:=
body
proxyURL
:=
""
proxyURL
:=
""
...
@@ -931,6 +937,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -931,6 +937,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
...
@@ -938,6 +951,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -938,6 +951,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream
:
req
.
Stream
,
Stream
:
req
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
},
nil
}
}
...
@@ -969,6 +984,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -969,6 +984,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
}
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body
=
ensureGeminiFunctionCallThoughtSignatures
(
body
)
mappedModel
:=
originalModel
mappedModel
:=
originalModel
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
...
@@ -1371,6 +1390,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -1371,6 +1390,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage
=
&
ClaudeUsage
{}
usage
=
&
ClaudeUsage
{}
}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
...
@@ -1378,6 +1404,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -1378,6 +1404,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream
:
stream
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
},
nil
}
}
...
@@ -2504,9 +2532,13 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
...
@@ -2504,9 +2532,13 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
}
}
prompt
,
_
:=
asInt
(
usageMeta
[
"promptTokenCount"
])
prompt
,
_
:=
asInt
(
usageMeta
[
"promptTokenCount"
])
cand
,
_
:=
asInt
(
usageMeta
[
"candidatesTokenCount"
])
cand
,
_
:=
asInt
(
usageMeta
[
"candidatesTokenCount"
])
cached
,
_
:=
asInt
(
usageMeta
[
"cachedContentTokenCount"
])
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return
&
ClaudeUsage
{
return
&
ClaudeUsage
{
InputTokens
:
prompt
,
InputTokens
:
prompt
-
cached
,
OutputTokens
:
cand
,
OutputTokens
:
cand
,
CacheReadInputTokens
:
cached
,
}
}
}
}
...
@@ -2635,6 +2667,58 @@ func nextGeminiDailyResetUnix() *int64 {
...
@@ -2635,6 +2667,58 @@ func nextGeminiDailyResetUnix() *int64 {
return
&
ts
return
&
ts
}
}
func
ensureGeminiFunctionCallThoughtSignatures
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when functionCall is present.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"functionCall"`
))
{
return
body
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
body
}
contentsAny
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contentsAny
)
==
0
{
return
body
}
modified
:=
false
for
_
,
c
:=
range
contentsAny
{
cm
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
partsAny
,
ok
:=
cm
[
"parts"
]
.
([]
any
)
if
!
ok
||
len
(
partsAny
)
==
0
{
continue
}
for
_
,
p
:=
range
partsAny
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
||
pm
==
nil
{
continue
}
if
fc
,
ok
:=
pm
[
"functionCall"
]
.
(
map
[
string
]
any
);
!
ok
||
fc
==
nil
{
continue
}
ts
,
_
:=
pm
[
"thoughtSignature"
]
.
(
string
)
if
strings
.
TrimSpace
(
ts
)
==
""
{
pm
[
"thoughtSignature"
]
=
geminiDummyThoughtSignature
modified
=
true
}
}
}
if
!
modified
{
return
body
}
b
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
body
}
return
b
}
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
...
@@ -2834,7 +2918,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
...
@@ -2834,7 +2918,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
toolUseIDToName
[
id
]
=
name
toolUseIDToName
[
id
]
=
name
}
}
signature
,
_
:=
bm
[
"signature"
]
.
(
string
)
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
signature
=
geminiDummyThoughtSignature
}
parts
=
append
(
parts
,
map
[
string
]
any
{
parts
=
append
(
parts
,
map
[
string
]
any
{
"thoughtSignature"
:
signature
,
"functionCall"
:
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
name
,
"name"
:
name
,
"args"
:
bm
[
"input"
],
"args"
:
bm
[
"input"
],
...
@@ -3031,3 +3121,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
...
@@ -3031,3 +3121,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
}
return
out
return
out
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func
(
s
*
GeminiMessagesCompatService
)
extractImageSize
(
body
[]
byte
)
string
{
var
req
struct
{
GenerationConfig
*
struct
{
ImageConfig
*
struct
{
ImageSize
string
`json:"imageSize"`
}
`json:"imageConfig"`
}
`json:"generationConfig"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
"2K"
}
if
req
.
GenerationConfig
!=
nil
&&
req
.
GenerationConfig
.
ImageConfig
!=
nil
{
size
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
req
.
GenerationConfig
.
ImageConfig
.
ImageSize
))
if
size
==
"1K"
||
size
==
"2K"
||
size
==
"4K"
{
return
size
}
}
return
"2K"
}
backend/internal/service/gemini_messages_compat_service_test.go
View file @
31fe0178
package
service
package
service
import
(
import
(
"encoding/json"
"strings"
"testing"
"testing"
)
)
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
})
}
}
}
}
func
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse
(
t
*
testing
.
T
)
{
claudeReq
:=
map
[
string
]
any
{
"model"
:
"claude-haiku-4-5-20251001"
,
"max_tokens"
:
10
,
"messages"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"hi"
},
},
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"ok"
},
map
[
string
]
any
{
"type"
:
"tool_use"
,
"id"
:
"toolu_123"
,
"name"
:
"default_api:write_file"
,
"input"
:
map
[
string
]
any
{
"path"
:
"a.txt"
,
"content"
:
"x"
},
// no signature on purpose
},
},
},
},
"tools"
:
[]
any
{
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"description"
:
"write file"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"path"
:
map
[
string
]
any
{
"type"
:
"string"
}},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
claudeReq
)
out
,
err
:=
convertClaudeMessagesToGeminiGenerateContent
(
b
)
if
err
!=
nil
{
t
.
Fatalf
(
"convert failed: %v"
,
err
)
}
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
functionCall
\"
"
)
{
t
.
Fatalf
(
"expected functionCall in output, got: %s"
,
s
)
}
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
func
TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing
(
t
*
testing
.
T
)
{
geminiReq
:=
map
[
string
]
any
{
"contents"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"parts"
:
[]
any
{
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"args"
:
map
[
string
]
any
{
"path"
:
"a.txt"
},
},
},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
geminiReq
)
out
:=
ensureGeminiFunctionCallThoughtSignatures
(
b
)
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
backend/internal/service/gemini_multiplatform_test.go
View file @
31fe0178
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGemini
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
...
...
backend/internal/service/gemini_native_signature_cleaner.go
0 → 100644
View file @
31fe0178
package
service
import
(
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
// 解析 JSON
var
data
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return
body
}
// 递归清理 thoughtSignature
cleaned
:=
cleanThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
cleaned
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
}
return
result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func
cleanThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,移除 thoughtSignature
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
// 跳过 thoughtSignature 字段
if
key
==
"thoughtSignature"
{
continue
}
// 递归处理嵌套结构
result
[
key
]
=
cleanThoughtSignaturesRecursive
(
value
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
cleanThoughtSignaturesRecursive
(
item
)
}
return
result
default
:
// 基本类型(string, number, bool, null)直接返回
return
v
}
}
backend/internal/service/group_service.go
View file @
31fe0178
...
@@ -29,6 +29,10 @@ type GroupRepository interface {
...
@@ -29,6 +29,10 @@ type GroupRepository interface {
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
// BindAccountsToGroup 将多个账号绑定到指定分组
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
}
}
// CreateGroupRequest 创建分组请求
// CreateGroupRequest 创建分组请求
...
...
backend/internal/service/identity_service.go
View file @
31fe0178
...
@@ -26,13 +26,13 @@ var (
...
@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用)
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.
0.6
2 (external, cli)"
,
UserAgent
:
"claude-cli/2.
1.2
2 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.
52
.0"
,
StainlessPackageVersion
:
"0.
70
.0"
,
StainlessOS
:
"Linux"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"
x
64"
,
StainlessArch
:
"
arm
64"
,
StainlessRuntime
:
"node"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v2
2
.1
4
.0"
,
StainlessRuntimeVersion
:
"v2
4
.1
3
.0"
,
}
}
// Fingerprint represents account fingerprint data
// Fingerprint represents account fingerprint data
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
...
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
}
}
// parseUserAgentVersion 解析user-agent版本号
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.
0.6
2 -> (2,
0
,
6
2)
// 例如:claude-cli/2.
1.
2 -> (2,
1
, 2)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment