Commit 2220fd18 authored by song's avatar song
Browse files

merge upstream main

parents 11ff73b5 df4c0adf
......@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{
{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{
"type": "ephemeral",
},
......
......@@ -115,6 +115,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
type UpdateGroupInput struct {
......@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
type CreateAccountInput struct {
......@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
mcpXMLInject = *input.MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与新分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
var err error
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
}
group := &Group{
Name: input.Name,
Description: input.Description,
......@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
// 如果有需要复制的账号,绑定到新分组
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
}
group.AccountCount = int64(len(accountIDsToCopy))
}
return group, nil
}
......@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
// 校验:源分组不能是自身
if srcGroupID == id {
return nil, fmt.Errorf("cannot copy accounts from self")
}
// 去重
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与当前分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != group.Platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
// 先清空当前分组的所有账号绑定
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
}
// 再绑定源分组的账号
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
}
}
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
......
......@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type proxyRepoStub struct {
deleteErr error
countErr error
......
......@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
......@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
......@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
}
......@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
}
// Antigravity 直接支持的模型(精确匹配透传)
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
......@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// 长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
// gemini-2.5 → gemini-3 映射(长前缀优先)
{"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
......
......@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return s.resp, s.err
}
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
......
......@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传
// 3. Gemini 2.5 → 3 映射
{
name: "Gemini透传 - gemini-2.5-flash",
name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
expected: "gemini-3-flash",
},
{
name: "Gemini透传 - gemini-2.5-pro",
name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-2.5-pro",
expected: "gemini-3-pro-high",
},
{
name: "Gemini透传 - gemini-future-model",
......
......@@ -19,17 +19,19 @@ import (
)
var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
......@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
userRepo UserRepository
redeemRepo RedeemCodeRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
......@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
func NewAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
......@@ -67,6 +71,7 @@ func NewAuthService(
) *AuthService {
return &AuthService{
userRepo: userRepo,
redeemRepo: redeemRepo,
cfg: cfg,
settingService: settingService,
emailService: emailService,
......@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "")
}
// RegisterWithVerification 用户注册(支持邮件验证优惠码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证优惠码和邀请码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
......@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrEmailReserved
}
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return "", nil, ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
......@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
// 邀请码标记失败不影响注册,只记录日志
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
}
}
// 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
......
......@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService(
repo,
nil, // redeemRepo
cfg,
settingService,
emailService,
......@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
......@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
......@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
......
......@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return s.CalculateCost(model, tokens, multiplier)
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
// 未启用长上下文计费,直接走正常计费
if threshold <= 0 || extraMultiplier <= 1 {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 计算总输入 token(缓存读取 + 新输入)
total := tokens.CacheReadTokens + tokens.InputTokens
if total <= threshold {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 拆分成范围内和范围外
var inRangeCacheTokens, inRangeInputTokens int
var outRangeCacheTokens, outRangeInputTokens int
if tokens.CacheReadTokens >= threshold {
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens = threshold
inRangeInputTokens = 0
outRangeCacheTokens = tokens.CacheReadTokens - threshold
outRangeInputTokens = tokens.InputTokens
} else {
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens = tokens.CacheReadTokens
inRangeInputTokens = threshold - tokens.CacheReadTokens
outRangeCacheTokens = 0
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
return nil, err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens := UsageTokens{
InputTokens: outRangeInputTokens,
CacheReadTokens: outRangeCacheTokens,
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, nil // 出错时返回范围内成本
}
// 合并成本
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
}, nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0)
......
......@@ -39,6 +39,7 @@ const (
RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation
)
// PromoCode status constants
......@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
......@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return 0, nil
}
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func ptr[T any](v T) *T {
return &v
}
......
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
......@@ -2,6 +2,7 @@ package service
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
......@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct {
name string
body string
......@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt",
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt",
wantSecondText: claudePrefix + "\n\nCustom prompt",
},
{
name: "string system equals Claude Code prompt",
......@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom",
wantSecondText: claudePrefix + "\n\nCustom",
},
{
name: "array system with existing Claude Code prompt (should dedupe)",
......@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other",
wantSecondText: claudePrefix + "\n\nOther",
},
{
name: "empty array",
......
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}
......@@ -20,12 +20,14 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
......@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 40 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
func (s *GatewayService) debugModelRoutingEnabled() bool {
......@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func (s *GatewayService) debugClaudeMimicEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func shortSessionHash(sessionHash string) string {
if sessionHash == "" {
return ""
......@@ -65,12 +79,178 @@ func normalizeClaudeModelForAnthropic(requestedModel string) string {
return requestedModel
}
func redactAuthHeaderValue(v string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
// Keep scheme for debugging, redact secret.
if strings.HasPrefix(strings.ToLower(v), "bearer ") {
return "Bearer [redacted]"
}
return "[redacted]"
}
func safeHeaderValueForLog(key string, v string) string {
key = strings.ToLower(strings.TrimSpace(key))
switch key {
case "authorization", "x-api-key":
return redactAuthHeaderValue(v)
default:
return strings.TrimSpace(v)
}
}
func extractSystemPreviewFromBody(body []byte) string {
if len(body) == 0 {
return ""
}
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return ""
}
switch {
case sys.IsArray():
for _, item := range sys.Array() {
if !item.IsObject() {
continue
}
if strings.EqualFold(item.Get("type").String(), "text") {
if t := item.Get("text").String(); strings.TrimSpace(t) != "" {
return t
}
}
}
return ""
case sys.Type == gjson.String:
return sys.String()
default:
return ""
}
}
func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string {
if req == nil {
return ""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting := []string{
"user-agent",
"x-app",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"anthropic-beta",
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-retry-count",
"x-stainless-timeout",
"authorization",
"x-api-key",
"content-type",
"accept",
"x-stainless-helper-method",
}
h := make([]string, 0, len(interesting))
for _, k := range interesting {
if v := req.Header.Get(k); v != "" {
h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v)))
}
}
metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String())
sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body))
// Truncate preview to keep logs sane.
if len(sysPreview) > 300 {
sysPreview = sysPreview[:300] + "..."
}
sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n")
sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r")
aid := int64(0)
aname := ""
if account != nil {
aid = account.ID
aname = account.Name
}
return fmt.Sprintf(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}",
req.URL.String(),
aid,
aname,
tokenType,
mimicClaudeCode,
metaUserID,
sysPreview,
strings.Join(h, " "),
)
}
func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) {
line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)
if line == "" {
return
}
log.Printf("[ClaudeMimicDebug] %s", line)
}
func isClaudeCodeCredentialScopeError(msg string) bool {
m := strings.ToLower(strings.TrimSpace(msg))
if m == "" {
return false
}
return strings.Contains(m, "only authorized for use with claude code") &&
strings.Contains(m, "cannot be used for other api requests")
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
"read": "Read",
"edit": "Edit",
"write": "Write",
"task": "Task",
"glob": "Glob",
"grep": "Grep",
"webfetch": "WebFetch",
"websearch": "WebSearch",
"todowrite": "TodoWrite",
"question": "AskUserQuestion",
}
openCodeToolOverrides = map[string]string{
"Bash": "bash",
"Read": "read",
"Edit": "edit",
"Write": "write",
"Task": "task",
"Glob": "glob",
"Grep": "grep",
"WebFetch": "webfetch",
"WebSearch": "websearch",
"TodoWrite": "todowrite",
"AskUserQuestion": "question",
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
......@@ -436,6 +616,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return newBody
}
type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
}
func stripToolPrefix(value string) string {
if value == "" {
return value
}
return toolPrefixRe.ReplaceAllString(value, "")
}
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
_, _ = builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string {
if value == "" {
return value
}
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
output = strings.Trim(output, "_")
return strings.ToLower(output)
}
func normalizeToolNameForClaude(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped
}
return mapped
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
return toSnakeCase(stripped)
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
if cache != nil {
if mapped, ok := cache[name]; ok {
return mapped
}
}
return name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func sanitizeSystemText(text string) string {
if text == "" {
return text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text = strings.ReplaceAll(
text,
"You are OpenCode, the best coding agent on the planet.",
strings.TrimSpace(claudeCodeSystemPrompt),
)
return text
}
func sanitizeToolDescription(description string) string {
if description == "" {
return description
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return description
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
schema, ok := inputSchema.(map[string]any)
if !ok {
return
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
return
}
newProperties := make(map[string]any, len(properties))
for key, value := range properties {
snakeKey := toSnakeCase(key)
newProperties[snakeKey] = value
if snakeKey != key && cache != nil {
cache[snakeKey] = key
}
}
schema["properties"] = newProperties
if required, ok := schema["required"].([]any); ok {
newRequired := make([]any, 0, len(required))
for _, item := range required {
name, ok := item.(string)
if !ok {
newRequired = append(newRequired, item)
continue
}
snakeName := toSnakeCase(name)
newRequired = append(newRequired, snakeName)
if snakeName != name && cache != nil {
cache[snakeName] = name
}
}
schema["required"] = newRequired
}
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
return false
}
changed := false
for _, item := range blocks {
block, ok := item.(map[string]any)
if !ok {
continue
}
if _, exists := block["cache_control"]; !exists {
continue
}
delete(block, "cache_control")
changed = true
}
return changed
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
if len(body) == 0 {
return body, modelID, nil
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
}
toolNameMap := make(map[string]string)
if system, ok := req["system"]; ok {
switch v := system.(type) {
case string:
sanitized := sanitizeSystemText(v)
if sanitized != v {
req["system"] = sanitized
}
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
if blockType, _ := block["type"].(string); blockType != "text" {
continue
}
text, ok := block["text"].(string)
if !ok || text == "" {
continue
}
sanitized := sanitizeSystemText(text)
if sanitized != text {
block["text"] = sanitized
}
}
}
}
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
req["model"] = normalized
modelID = normalized
}
}
if rawTools, exists := req["tools"]; exists {
switch tools := rawTools.(type) {
case []any:
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
if name, ok := toolMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
tools[idx] = toolMap
}
req["tools"] = tools
case map[string]any:
normalizedTools := make(map[string]any, len(tools))
for name, value := range tools {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized == "" {
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
toolMap["name"] = normalized
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
}
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
}
} else {
req["tools"] = []any{}
}
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
continue
}
if name, ok := blockMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
}
}
}
}
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
}
}
if opts.injectMetadata && opts.metadataUserID != "" {
metadata, ok := req["metadata"].(map[string]any)
if !ok {
metadata = map[string]any{}
req["metadata"] = metadata
}
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID
}
}
delete(req, "temperature")
delete(req, "tool_choice")
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
return newBody, modelID, toolNameMap
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
if parsed == nil || account == nil {
return ""
}
if parsed.MetadataUserID != "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID = generateClientID()
}
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
if accountUUID != "" {
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
}
func generateSessionUUID(seed string) string {
if seed == "" {
return uuid.NewString()
}
hash := sha256.Sum256([]byte(seed))
bytes := hash[:16]
bytes[6] = (bytes[6] & 0x0f) | 0x40
bytes[8] = (bytes[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
......@@ -2060,6 +2628,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent)
}
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
......@@ -2096,6 +2674,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
var newSystem []any
......@@ -2103,19 +2685,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case nil:
newSystem = []any{claudeCodeBlock}
case string:
if v == "" || v == claudeCodeSystemPrompt {
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
newSystem = []any{claudeCodeBlock}
} else {
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged := v
if !strings.HasPrefix(v, claudeCodePrefix) {
merged = claudeCodePrefix + "\n\n" + v
}
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
}
case []any:
newSystem = make([]any, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock)
prefixedNext := false
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
continue
}
// Prefix the first subsequent text system block once.
if !prefixedNext {
if blockType, _ := m["type"].(string); blockType == "text" {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
m["text"] = claudeCodePrefix + "\n\n" + text
prefixedNext = true
}
}
}
}
newSystem = append(newSystem, item)
}
......@@ -2319,21 +2918,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
originalModel := reqModel
var toolNameMap map[string]string
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if account.IsOAuth() &&
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
originalModel := reqModel
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
......@@ -2377,10 +2993,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
......@@ -2458,7 +3073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
......@@ -2490,7 +3105,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil {
......@@ -2715,7 +3330,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
......@@ -2728,7 +3343,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
......@@ -2745,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
......@@ -2759,11 +3374,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth账号:应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
......@@ -2794,7 +3414,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 白名单透传headers
for key, values := range c.Request.Header {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
......@@ -2815,10 +3435,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// 处理anthropic-beta header(OAuth账号需要特殊处理
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
if mimicClaudeCode {
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := req.Header.Get("anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
......@@ -2828,6 +3468,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil
}
......@@ -2897,30 +3546,117 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
if req == nil {
return
}
if len(b) > maxBytes {
b = b[:maxBytes]
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "application/json")
}
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
if req.Header.Get(key) == "" {
req.Header.Set(key, value)
}
}
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
req.Header.Set("x-stainless-helper-method", "stream")
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
func mergeAnthropicBeta(required []string, incoming string) string {
seen := make(map[string]struct{}, len(required)+8)
out := make([]string, 0, len(required)+8)
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
add := func(v string) {
v = strings.TrimSpace(v)
if v == "" {
return
}
if _, ok := seen[v]; ok {
return
}
seen[v] = struct{}{}
out = append(out, v)
}
for _, r := range required {
add(r)
}
for _, p := range strings.Split(incoming, ",") {
add(p)
}
return strings.Join(out, ",")
}
func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string {
merged := mergeAnthropicBeta(required, incoming)
if merged == "" || len(drop) == 0 {
return merged
}
out := make([]string, 0, 8)
for _, p := range strings.Split(merged, ",") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := drop[p]; ok {
continue
}
out = append(out, p)
}
return strings.Join(out, ",")
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if req == nil {
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults(req, isStream)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
req.Header.Set(key, value)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req.Header.Set("accept", "application/json")
if isStream {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误(更宽松的匹配)
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
......@@ -3000,6 +3736,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
......@@ -3129,6 +3879,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
......@@ -3181,7 +3944,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -3276,6 +4039,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines := make([]string, 0, 4)
var toolInputBuffers map[int]string
if mimicClaudeCode {
toolInputBuffers = make(map[int]string)
}
transformToolInputJSON := func(raw string) string {
if !mimicClaudeCode {
return raw
}
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
var parsed any
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
return replaceToolNamesInText(raw, toolNameMap)
}
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
if changed {
if bytes, err := json.Marshal(rewritten); err == nil {
return string(bytes)
}
}
return raw
}
processSSEEvent := func(lines []string) ([]string, string, error) {
if len(lines) == 0 {
return nil, "", nil
}
eventName := ""
dataLine := ""
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "event:") {
eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
continue
}
if dataLine == "" && sseDataRe.MatchString(trimmed) {
dataLine = sseDataRe.ReplaceAllString(trimmed, "")
}
}
if eventName == "error" {
return nil, dataLine, errors.New("have error in stream")
}
if dataLine == "" {
return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
}
if dataLine == "[DONE]" {
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
var event map[string]any
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
}
eventType, _ := event["type"].(string)
if eventName == "" {
eventName = eventType
}
if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
msg["model"] = originalModel
}
}
}
if mimicClaudeCode && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuffers[index] += partial
}
}
return nil, dataLine, nil
}
}
}
if mimicClaudeCode && eventType == "content_block_stop" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if buffered := toolInputBuffers[index]; buffered != "" {
delete(toolInputBuffers, index)
transformed := transformToolInputJSON(buffered)
synthetic := map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": transformed,
},
}
synthBytes, synthErr := json.Marshal(synthetic)
if synthErr == nil {
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
rewriteToolNamesInValue(event, toolNameMap)
stopBytes, stopErr := json.Marshal(event)
if stopErr == nil {
stopBlock := ""
if eventName != "" {
stopBlock = "event: " + eventName + "\n"
}
stopBlock += "data: " + string(stopBytes) + "\n\n"
return []string{synthBlock, stopBlock}, string(stopBytes), nil
}
}
}
}
}
if mimicClaudeCode {
rewriteToolNamesInValue(event, toolNameMap)
}
newData, err := json.Marshal(event)
if err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
}
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + string(newData) + "\n\n"
return []string{block}, string(newData), nil
}
for {
select {
case ev, ok := <-events:
......@@ -3304,43 +4232,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream")
}
trimmed := strings.TrimSpace(line)
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射,替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
if trimmed == "" {
if len(pendingEventLines) == 0 {
continue
}
}
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
outputBlocks, data, err := processSSEEvent(pendingEventLines)
pendingEventLines = pendingEventLines[:0]
if err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, err
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
for _, block := range outputBlocks {
if !clientDisconnected {
if _, werr := fmt.Fprint(w, block); werr != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
break
}
flusher.Flush()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
}
s.parseSSEUsage(data, usage)
continue
}
pendingEventLines = append(pendingEventLines, line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
......@@ -3363,43 +4292,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// 只替换 message_start 事件中的 message.model
if event["type"] != "message_start" {
return line
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
switch v := value.(type) {
case map[string]any:
changed := false
rewritten := make(map[string]any, len(v))
for key, item := range v {
newKey := normalizeParamNameForOpenCode(key, cache)
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
if newKey != key {
changed = true
}
rewritten[newKey] = newItem
}
if !changed {
return value, false
}
return rewritten, true
case []any:
changed := false
rewritten := make([]any, len(v))
for idx, item := range v {
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
rewritten[idx] = newItem
}
if !changed {
return value, false
}
return rewritten, true
default:
return value, false
}
}
msg, ok := event["message"].(map[string]any)
if !ok {
return line
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
changed := false
if blockType, _ := v["type"].(string); blockType == "tool_use" {
if name, ok := v["name"].(string); ok {
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped != name {
v["name"] = mapped
changed = true
}
}
if input, ok := v["input"].(map[string]any); ok {
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
if inputChanged {
if m, ok := rewrittenInput.(map[string]any); ok {
v["input"] = m
changed = true
}
}
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
case []any:
changed := false
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
default:
return false
}
}
model, ok := msg["model"].(string)
if !ok || model != fromModel {
return line
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
if text == "" {
return text
}
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
submatches := toolNameFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
name := submatches[1]
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped == name {
return match
}
return strings.Replace(match, name, mapped, 1)
})
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
submatches := modelFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
model := submatches[1]
mapped := claude.DenormalizeModelID(model)
if mapped == model {
return match
}
return strings.Replace(match, model, mapped, 1)
})
msg["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
for mapped, original := range toolNameMap {
if mapped == "" || original == "" || mapped == original {
continue
}
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
}
return "data: " + string(newData)
return output
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
......@@ -3445,7 +4455,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -3466,6 +4476,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
......@@ -3503,6 +4516,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody
}
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
if len(body) == 0 {
return body
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
replaced := replaceToolNamesInText(string(body), toolNameMap)
if replaced == string(body) {
return body
}
return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
}
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
......@@ -3657,6 +4692,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type RecordUsageLongContextInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
result := input.Result
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
// 获取费率倍数
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
}
var cost *CostBreakdown
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil {
log.Printf("Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......@@ -3668,6 +4859,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
......@@ -3706,7 +4905,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
......@@ -3739,7 +4938,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
......@@ -3804,7 +5003,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
......@@ -3818,10 +5017,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
......@@ -3845,7 +5049,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 白名单透传 headers
for key, values := range c.Request.Header {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
......@@ -3856,7 +5060,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
......@@ -3869,10 +5073,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
if mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
} else {
beta := s.getBetaHeader(modelID, clientBetaHeader)
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
req.Header.Set("anthropic-beta", beta)
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
......@@ -3882,6 +5106,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil
}
......
......@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay = 16 * time.Second
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
groupRepo GroupRepository
......@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
originalClaudeBody := body
proxyURL := ""
......@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
......@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
return &ts
}
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
// Fast path: only run when functionCall is present.
if !bytes.Contains(body, []byte(`"functionCall"`)) {
return body
}
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
contentsAny, ok := payload["contents"].([]any)
if !ok || len(contentsAny) == 0 {
return body
}
modified := false
for _, c := range contentsAny {
cm, ok := c.(map[string]any)
if !ok {
continue
}
partsAny, ok := cm["parts"].([]any)
if !ok || len(partsAny) == 0 {
continue
}
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok || pm == nil {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
continue
}
ts, _ := pm["thoughtSignature"].(string)
if strings.TrimSpace(ts) == "" {
pm["thoughtSignature"] = geminiDummyThoughtSignature
modified = true
}
}
}
if !modified {
return body
}
b, err := json.Marshal(payload)
if err != nil {
return body
}
return b
}
func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
......@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name
}
signature, _ := bm["signature"].(string)
signature = strings.TrimSpace(signature)
if signature == "" {
signature = geminiDummyThoughtSignature
}
parts = append(parts, map[string]any{
"thoughtSignature": signature,
"functionCall": map[string]any{
"name": name,
"args": bm["input"],
......
package service
import (
"encoding/json"
"strings"
"testing"
)
......@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
}
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",
"max_tokens": 10,
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "hi"},
},
},
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "ok"},
map[string]any{
"type": "tool_use",
"id": "toolu_123",
"name": "default_api:write_file",
"input": map[string]any{"path": "a.txt", "content": "x"},
// no signature on purpose
},
},
},
},
"tools": []any{
map[string]any{
"name": "default_api:write_file",
"description": "write file",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{"path": map[string]any{"type": "string"}},
},
},
},
}
b, _ := json.Marshal(claudeReq)
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
if err != nil {
t.Fatalf("convert failed: %v", err)
}
s := string(out)
if !strings.Contains(s, "\"functionCall\"") {
t.Fatalf("expected functionCall in output, got: %s", s)
}
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
geminiReq := map[string]any{
"contents": []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "default_api:write_file",
"args": map[string]any{"path": "a.txt"},
},
},
},
},
},
}
b, _ := json.Marshal(geminiReq)
out := ensureGeminiFunctionCallThoughtSignatures(b)
s := string(out)
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
......@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return 0, nil
}
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment