Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2220fd18
Commit
2220fd18
authored
Feb 03, 2026
by
song
Browse files
merge upstream main
parents
11ff73b5
df4c0adf
Changes
67
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/account_test_service.go
View file @
2220fd18
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
"system"
:
[]
map
[
string
]
any
{
{
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
"type"
:
"ephemeral"
,
},
},
...
...
backend/internal/service/admin_service.go
View file @
2220fd18
...
@@ -115,6 +115,8 @@ type CreateGroupInput struct {
...
@@ -115,6 +115,8 @@ type CreateGroupInput struct {
MCPXMLInject
*
bool
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
SupportedModelScopes
[]
string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
}
type
UpdateGroupInput
struct
{
type
UpdateGroupInput
struct
{
...
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
...
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
MCPXMLInject
*
bool
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
SupportedModelScopes
*
[]
string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
}
type
CreateAccountInput
struct
{
type
CreateAccountInput
struct
{
...
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
mcpXMLInject
=
*
input
.
MCPXMLInject
mcpXMLInject
=
*
input
.
MCPXMLInject
}
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与新分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
var
err
error
accountIDsToCopy
,
err
=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
}
group
:=
&
Group
{
group
:=
&
Group
{
Name
:
input
.
Name
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Description
:
input
.
Description
,
...
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 如果有需要复制的账号,绑定到新分组
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
group
.
ID
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to new group: %w"
,
err
)
}
group
.
AccountCount
=
int64
(
len
(
accountIDsToCopy
))
}
return
group
,
nil
return
group
,
nil
}
}
...
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
// 校验:源分组不能是自身
if
srcGroupID
==
id
{
return
nil
,
fmt
.
Errorf
(
"cannot copy accounts from self"
)
}
// 去重
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与当前分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
group
.
Platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
group
.
Platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy
,
err
:=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
// 先清空当前分组的所有账号绑定
if
_
,
err
:=
s
.
groupRepo
.
DeleteAccountGroupsByGroupID
(
ctx
,
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to clear existing account bindings: %w"
,
err
)
}
// 再绑定源分组的账号
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
id
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to group: %w"
,
err
)
}
}
}
if
s
.
authCacheInvalidator
!=
nil
{
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
}
...
...
backend/internal/service/admin_service_delete_test.go
View file @
2220fd18
...
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
...
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStub
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStub
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
proxyRepoStub
struct
{
type
proxyRepoStub
struct
{
deleteErr
error
deleteErr
error
countErr
error
countErr
error
...
...
backend/internal/service/admin_service_group_test.go
View file @
2220fd18
...
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
...
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStubForAdmin
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
repo
:=
&
groupRepoStubForAdmin
{}
...
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
...
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStubForFallbackCycle
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
groupRepoStubForInvalidRequestFallback
struct
{
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
groups
map
[
int64
]
*
Group
created
*
Group
created
*
Group
...
@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
...
@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
}
backend/internal/service/antigravity_gateway_service.go
View file @
2220fd18
...
@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
...
@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
}
}
// Antigravity 直接支持的模型(精确匹配透传)
// Antigravity 直接支持的模型(精确匹配透传)
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var
antigravitySupportedModels
=
map
[
string
]
bool
{
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"gemini-2.5-flash"
:
true
,
"gemini-2.5-flash-lite"
:
true
,
"gemini-2.5-flash-thinking"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
"gemini-3-pro-high"
:
true
,
...
@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
...
@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
var
antigravityPrefixMapping
=
[]
struct
{
var
antigravityPrefixMapping
=
[]
struct
{
prefix
string
prefix
string
target
string
target
string
}{
}{
// 长前缀优先
// gemini-2.5 → gemini-3 映射(长前缀优先)
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → 3-pro-image
{
"gemini-2.5-flash-thinking"
,
"gemini-3-flash"
},
// gemini-2.5-flash-thinking → gemini-3-flash
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → gemini-3-pro-image
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"gemini-2.5-flash-lite"
,
"gemini-3-flash"
},
// gemini-2.5-flash-lite → gemini-3-flash
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"gemini-2.5-flash"
,
"gemini-3-flash"
},
// gemini-2.5-flash → gemini-3-flash
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"gemini-2.5-pro-preview"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-preview → gemini-3-pro-high
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
{
"gemini-2.5-pro-exp"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-exp → gemini-3-pro-high
{
"gemini-2.5-pro"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-3-haiku"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-haiku-xxx → sonnet
{
"claude-3-haiku"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-haiku-xxx → sonnet
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-haiku-4"
,
"claude-sonnet-4-5"
},
// → sonnet
{
"claude-haiku-4"
,
"claude-sonnet-4-5"
},
// → sonnet
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
}
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...
...
backend/internal/service/antigravity_gateway_service_test.go
View file @
2220fd18
...
@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
...
@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return
s
.
resp
,
s
.
err
return
s
.
resp
,
s
.
err
}
}
func
(
s
*
httpUpstreamStub
)
DoWithTLS
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
,
_
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
writer
:=
httptest
.
NewRecorder
()
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
2220fd18
...
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
...
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"claude-sonnet-4-5"
,
expected
:
"claude-sonnet-4-5"
,
},
},
// 3. Gemini
透传
// 3. Gemini
2.5 → 3 映射
{
{
name
:
"Gemini
透传
- gemini-2.5-flash"
,
name
:
"Gemini
映射
- gemini-2.5-flash
→ gemini-3-flash
"
,
requestedModel
:
"gemini-2.5-flash"
,
requestedModel
:
"gemini-2.5-flash"
,
accountMapping
:
nil
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-flash"
,
expected
:
"gemini-
3
-flash"
,
},
},
{
{
name
:
"Gemini
透传
- gemini-2.5-pro"
,
name
:
"Gemini
映射
- gemini-2.5-pro
→ gemini-3-pro-high
"
,
requestedModel
:
"gemini-2.5-pro"
,
requestedModel
:
"gemini-2.5-pro"
,
accountMapping
:
nil
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-pro"
,
expected
:
"gemini-
3
-pro
-high
"
,
},
},
{
{
name
:
"Gemini透传 - gemini-future-model"
,
name
:
"Gemini透传 - gemini-future-model"
,
...
...
backend/internal/service/auth_service.go
View file @
2220fd18
...
@@ -19,17 +19,19 @@ import (
...
@@ -19,17 +19,19 @@ import (
)
)
var
(
var
(
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrInvitationCodeRequired
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_REQUIRED"
,
"invitation code is required"
)
ErrInvitationCodeInvalid
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_INVALID"
,
"invalid or used invitation code"
)
)
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
// AuthService 认证服务
type
AuthService
struct
{
type
AuthService
struct
{
userRepo
UserRepository
userRepo
UserRepository
redeemRepo
RedeemCodeRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
settingService
*
SettingService
settingService
*
SettingService
emailService
*
EmailService
emailService
*
EmailService
...
@@ -58,6 +61,7 @@ type AuthService struct {
...
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
// NewAuthService 创建认证服务实例
func
NewAuthService
(
func
NewAuthService
(
userRepo
UserRepository
,
userRepo
UserRepository
,
redeemRepo
RedeemCodeRepository
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
settingService
*
SettingService
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailService
*
EmailService
,
...
@@ -67,6 +71,7 @@ func NewAuthService(
...
@@ -67,6 +71,7 @@ func NewAuthService(
)
*
AuthService
{
)
*
AuthService
{
return
&
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
userRepo
:
userRepo
,
redeemRepo
:
redeemRepo
,
cfg
:
cfg
,
cfg
:
cfg
,
settingService
:
settingService
,
settingService
:
settingService
,
emailService
:
emailService
,
emailService
:
emailService
,
...
@@ -78,11 +83,11 @@ func NewAuthService(
...
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
)
}
}
// RegisterWithVerification 用户注册(支持邮件验证
和
优惠码),返回token和用户
// RegisterWithVerification 用户注册(支持邮件验证
、
优惠码
和邀请码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
return
""
,
nil
,
ErrRegDisabled
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrEmailReserved
return
""
,
nil
,
ErrEmailReserved
}
}
// 检查是否需要邀请码
var
invitationRedeemCode
*
RedeemCode
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsInvitationCodeEnabled
(
ctx
)
{
if
invitationCode
==
""
{
return
""
,
nil
,
ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
// 检查类型和状态
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
log
.
Printf
(
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
invitationRedeemCode
=
redeemCode
}
// 检查是否需要邮件验证
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
...
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
// 邀请码标记失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
backend/internal/service/auth_service_register_test.go
View file @
2220fd18
...
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
...
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return
NewAuthService
(
return
NewAuthService
(
repo
,
repo
,
nil
,
// redeemRepo
cfg
,
cfg
,
settingService
,
settingService
,
emailService
,
emailService
,
...
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
...
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
}
...
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
...
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
}
...
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
...
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
}
...
...
backend/internal/service/billing_service.go
View file @
2220fd18
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
}
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func
(
s
*
BillingService
)
CalculateCostWithLongContext
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
threshold
int
,
extraMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
// 未启用长上下文计费,直接走正常计费
if
threshold
<=
0
||
extraMultiplier
<=
1
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 计算总输入 token(缓存读取 + 新输入)
total
:=
tokens
.
CacheReadTokens
+
tokens
.
InputTokens
if
total
<=
threshold
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 拆分成范围内和范围外
var
inRangeCacheTokens
,
inRangeInputTokens
int
var
outRangeCacheTokens
,
outRangeInputTokens
int
if
tokens
.
CacheReadTokens
>=
threshold
{
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens
=
threshold
inRangeInputTokens
=
0
outRangeCacheTokens
=
tokens
.
CacheReadTokens
-
threshold
outRangeInputTokens
=
tokens
.
InputTokens
}
else
{
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens
=
tokens
.
CacheReadTokens
inRangeInputTokens
=
threshold
-
tokens
.
CacheReadTokens
outRangeCacheTokens
=
0
outRangeInputTokens
=
tokens
.
InputTokens
-
inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens
:=
UsageTokens
{
InputTokens
:
inRangeInputTokens
,
OutputTokens
:
tokens
.
OutputTokens
,
// 输出只算一次
CacheCreationTokens
:
tokens
.
CacheCreationTokens
,
CacheReadTokens
:
inRangeCacheTokens
,
}
inRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
inRangeTokens
,
rateMultiplier
)
if
err
!=
nil
{
return
nil
,
err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens
:=
UsageTokens
{
InputTokens
:
outRangeInputTokens
,
CacheReadTokens
:
outRangeCacheTokens
,
}
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
if
err
!=
nil
{
return
inRangeCost
,
nil
// 出错时返回范围内成本
}
// 合并成本
return
&
CostBreakdown
{
InputCost
:
inRangeCost
.
InputCost
+
outRangeCost
.
InputCost
,
OutputCost
:
inRangeCost
.
OutputCost
,
CacheCreationCost
:
inRangeCost
.
CacheCreationCost
,
CacheReadCost
:
inRangeCost
.
CacheReadCost
+
outRangeCost
.
CacheReadCost
,
TotalCost
:
inRangeCost
.
TotalCost
+
outRangeCost
.
TotalCost
,
ActualCost
:
inRangeCost
.
ActualCost
+
outRangeCost
.
ActualCost
,
},
nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
models
:=
make
([]
string
,
0
)
models
:=
make
([]
string
,
0
)
...
...
backend/internal/service/domain_constants.go
View file @
2220fd18
...
@@ -39,6 +39,7 @@ const (
...
@@ -39,6 +39,7 @@ const (
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeInvitation
=
domain
.
RedeemTypeInvitation
)
)
// PromoCode status constants
// PromoCode status constants
...
@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
...
@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
// Setting keys
const
(
const
(
// 注册设置
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled
=
"invitation_code_enabled"
// 是否启用邀请码注册
// 邮件服务设置
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
...
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
2220fd18
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGateway
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
return
&
v
}
}
...
...
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
2220fd18
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
tests
:=
[]
struct
{
name
string
name
string
body
string
body
string
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
},
{
{
name
:
"string system equals Claude Code prompt"
,
name
:
"string system equals Claude Code prompt"
,
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
},
{
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
},
{
{
name
:
"empty array"
,
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
2220fd18
...
@@ -20,12 +20,14 @@ import (
...
@@ -20,12 +20,14 @@ import (
"strings"
"strings"
"sync/atomic"
"sync/atomic"
"time"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
...
@@ -37,8 +39,15 @@ const (
...
@@ -37,8 +39,15 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
defaultMaxLineSize
=
40
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
...
@@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
if
sessionHash
==
""
{
return
""
return
""
...
@@ -65,12 +79,178 @@ func normalizeClaudeModelForAnthropic(requestedModel string) string {
...
@@ -65,12 +79,178 @@ func normalizeClaudeModelForAnthropic(requestedModel string) string {
return
requestedModel
return
requestedModel
}
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
var
(
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sseDataRe
=
regexp
.
MustCompile
(
`^data:\s*`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
sessionIDRegex
=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
claudeCliUserAgentRe
=
regexp
.
MustCompile
(
`^claude-cli/\d+\.\d+\.\d+`
)
toolPrefixRe
=
regexp
.
MustCompile
(
`(?i)^(?:oc_|mcp_)`
)
toolNameBoundaryRe
=
regexp
.
MustCompile
(
`[^a-zA-Z0-9]+`
)
toolNameCamelRe
=
regexp
.
MustCompile
(
`([a-z0-9])([A-Z])`
)
toolNameFieldRe
=
regexp
.
MustCompile
(
`"name"\s*:\s*"([^"]+)"`
)
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
"read"
:
"Read"
,
"edit"
:
"Edit"
,
"write"
:
"Write"
,
"task"
:
"Task"
,
"glob"
:
"Glob"
,
"grep"
:
"Grep"
,
"webfetch"
:
"WebFetch"
,
"websearch"
:
"WebSearch"
,
"todowrite"
:
"TodoWrite"
,
"question"
:
"AskUserQuestion"
,
}
openCodeToolOverrides
=
map
[
string
]
string
{
"Bash"
:
"bash"
,
"Read"
:
"read"
,
"Edit"
:
"edit"
,
"Write"
:
"write"
,
"Task"
:
"task"
,
"Glob"
:
"glob"
,
"Grep"
:
"grep"
,
"WebFetch"
:
"webfetch"
,
"WebSearch"
:
"websearch"
,
"TodoWrite"
:
"todowrite"
,
"AskUserQuestion"
:
"question"
,
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
...
@@ -436,6 +616,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
...
@@ -436,6 +616,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return
newBody
return
newBody
}
}
type
claudeOAuthNormalizeOptions
struct
{
injectMetadata
bool
metadataUserID
string
stripSystemCacheControl
bool
}
func
stripToolPrefix
(
value
string
)
string
{
if
value
==
""
{
return
value
}
return
toolPrefixRe
.
ReplaceAllString
(
value
,
""
)
}
func
toPascalCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
normalized
:=
toolNameBoundaryRe
.
ReplaceAllString
(
value
,
" "
)
tokens
:=
make
([]
string
,
0
)
for
_
,
token
:=
range
strings
.
Fields
(
normalized
)
{
expanded
:=
toolNameCamelRe
.
ReplaceAllString
(
token
,
"$1 $2"
)
parts
:=
strings
.
Fields
(
expanded
)
if
len
(
parts
)
>
0
{
tokens
=
append
(
tokens
,
parts
...
)
}
}
if
len
(
tokens
)
==
0
{
return
value
}
var
builder
strings
.
Builder
for
_
,
token
:=
range
tokens
{
lower
:=
strings
.
ToLower
(
token
)
if
lower
==
""
{
continue
}
runes
:=
[]
rune
(
lower
)
runes
[
0
]
=
unicode
.
ToUpper
(
runes
[
0
])
_
,
_
=
builder
.
WriteString
(
string
(
runes
))
}
return
builder
.
String
()
}
func
toSnakeCase
(
value
string
)
string
{
if
value
==
""
{
return
value
}
output
:=
toolNameCamelRe
.
ReplaceAllString
(
value
,
"$1_$2"
)
output
=
toolNameBoundaryRe
.
ReplaceAllString
(
output
,
"_"
)
output
=
strings
.
Trim
(
output
,
"_"
)
return
strings
.
ToLower
(
output
)
}
func
normalizeToolNameForClaude
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
mapped
,
ok
:=
claudeToolNameOverrides
[
strings
.
ToLower
(
stripped
)]
if
!
ok
{
mapped
=
toPascalCase
(
stripped
)
}
if
mapped
!=
""
&&
cache
!=
nil
&&
mapped
!=
stripped
{
cache
[
mapped
]
=
stripped
}
if
mapped
==
""
{
return
stripped
}
return
mapped
}
func
normalizeToolNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
stripped
:=
stripToolPrefix
(
name
)
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
stripped
];
ok
{
return
mapped
}
}
if
mapped
,
ok
:=
openCodeToolOverrides
[
stripped
];
ok
{
return
mapped
}
return
toSnakeCase
(
stripped
)
}
func
normalizeParamNameForOpenCode
(
name
string
,
cache
map
[
string
]
string
)
string
{
if
name
==
""
{
return
name
}
if
cache
!=
nil
{
if
mapped
,
ok
:=
cache
[
name
];
ok
{
return
mapped
}
}
return
name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
func
sanitizeToolDescription
(
description
string
)
string
{
if
description
==
""
{
return
description
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
schema
,
ok
:=
inputSchema
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
properties
,
ok
:=
schema
[
"properties"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
newProperties
:=
make
(
map
[
string
]
any
,
len
(
properties
))
for
key
,
value
:=
range
properties
{
snakeKey
:=
toSnakeCase
(
key
)
newProperties
[
snakeKey
]
=
value
if
snakeKey
!=
key
&&
cache
!=
nil
{
cache
[
snakeKey
]
=
key
}
}
schema
[
"properties"
]
=
newProperties
if
required
,
ok
:=
schema
[
"required"
]
.
([]
any
);
ok
{
newRequired
:=
make
([]
any
,
0
,
len
(
required
))
for
_
,
item
:=
range
required
{
name
,
ok
:=
item
.
(
string
)
if
!
ok
{
newRequired
=
append
(
newRequired
,
item
)
continue
}
snakeName
:=
toSnakeCase
(
name
)
newRequired
=
append
(
newRequired
,
snakeName
)
if
snakeName
!=
name
&&
cache
!=
nil
{
cache
[
snakeName
]
=
name
}
}
schema
[
"required"
]
=
newRequired
}
}
func
stripCacheControlFromSystemBlocks
(
system
any
)
bool
{
blocks
,
ok
:=
system
.
([]
any
)
if
!
ok
{
return
false
}
changed
:=
false
for
_
,
item
:=
range
blocks
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
_
,
exists
:=
block
[
"cache_control"
];
!
exists
{
continue
}
delete
(
block
,
"cache_control"
)
changed
=
true
}
return
changed
}
func
normalizeClaudeOAuthRequestBody
(
body
[]
byte
,
modelID
string
,
opts
claudeOAuthNormalizeOptions
)
([]
byte
,
string
,
map
[
string
]
string
)
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
case
[]
any
:
for
_
,
item
:=
range
v
{
block
,
ok
:=
item
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
block
[
"type"
]
.
(
string
);
blockType
!=
"text"
{
continue
}
text
,
ok
:=
block
[
"text"
]
.
(
string
)
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
}
}
}
if
rawModel
,
ok
:=
req
[
"model"
]
.
(
string
);
ok
{
normalized
:=
claude
.
NormalizeModelID
(
rawModel
)
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
}
}
if
rawTools
,
exists
:=
req
[
"tools"
];
exists
{
switch
tools
:=
rawTools
.
(
type
)
{
case
[]
any
:
for
idx
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
tools
[
idx
]
=
toolMap
}
req
[
"tools"
]
=
tools
case
map
[
string
]
any
:
normalizedTools
:=
make
(
map
[
string
]
any
,
len
(
tools
))
for
name
,
value
:=
range
tools
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
==
""
{
normalized
=
name
}
if
toolMap
,
ok
:=
value
.
(
map
[
string
]
any
);
ok
{
toolMap
[
"name"
]
=
normalized
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
}
normalizedTools
[
normalized
]
=
toolMap
continue
}
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
}
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
continue
}
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
if
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
);
blockType
!=
"tool_use"
{
continue
}
if
name
,
ok
:=
blockMap
[
"name"
]
.
(
string
);
ok
{
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
}
}
}
}
}
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
}
}
if
opts
.
injectMetadata
&&
opts
.
metadataUserID
!=
""
{
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
metadata
=
map
[
string
]
any
{}
req
[
"metadata"
]
=
metadata
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
return
newBody
,
modelID
,
toolNameMap
}
func
(
s
*
GatewayService
)
buildOAuthMetadataUserID
(
parsed
*
ParsedRequest
,
account
*
Account
,
fp
*
Fingerprint
)
string
{
if
parsed
==
nil
||
account
==
nil
{
return
""
}
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
sessionID
:=
uuid
.
NewString
()
if
sessionHash
!=
""
{
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
if
seed
==
""
{
return
uuid
.
NewString
()
}
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// SelectAccount 选择账号(粘性会话+优先级)
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
...
@@ -2060,6 +2628,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
...
@@ -2060,6 +2628,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
return
claudeCliUserAgentRe
.
MatchString
(
userAgent
)
}
}
func
isClaudeCodeRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
parsed
*
ParsedRequest
)
bool
{
if
IsClaudeCodeClient
(
ctx
)
{
return
true
}
if
parsed
==
nil
||
c
==
nil
{
return
false
}
return
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
func
systemIncludesClaudeCodePrompt
(
system
any
)
bool
{
...
@@ -2096,6 +2674,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2096,6 +2674,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
var
newSystem
[]
any
...
@@ -2103,19 +2685,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
...
@@ -2103,19 +2685,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
}
case
[]
any
:
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
continue
}
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
newSystem
=
append
(
newSystem
,
item
)
}
}
...
@@ -2319,21 +2918,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2319,21 +2918,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
reqStream
:=
parsed
.
Stream
originalModel
:=
reqModel
var
toolNameMap
map
[
string
]
string
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
if
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
&&
fp
!=
nil
{
if
metadataUserID
:=
s
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
);
metadataUserID
!=
""
{
normalizeOpts
.
injectMetadata
=
true
normalizeOpts
.
metadataUserID
=
metadataUserID
}
}
}
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
body
,
reqModel
,
toolNameMap
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if
account
.
IsOAuth
()
&&
!
isClaudeCodeClient
(
c
.
GetHeader
(
"User-Agent"
),
parsed
.
MetadataUserID
)
&&
!
strings
.
Contains
(
strings
.
ToLower
(
reqModel
),
"haiku"
)
&&
!
systemIncludesClaudeCodePrompt
(
parsed
.
System
)
{
body
=
injectClaudeCodePrompt
(
body
,
parsed
.
System
)
}
}
// 强制执行 cache_control 块数量限制(最多 4 个)
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
originalModel
:=
reqModel
mappedModel
:=
reqModel
mappedModel
:=
reqModel
mappingSource
:=
""
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -2377,10 +2993,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2377,10 +2993,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2458,7 +3073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2458,7 +3073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -2490,7 +3105,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2490,7 +3105,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
if
retryErr2
==
nil
{
...
@@ -2715,7 +3330,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2715,7 +3330,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
var
clientDisconnect
bool
if
reqStream
{
if
reqStream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
)
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
.
Error
()
==
"have error in stream"
{
if
err
.
Error
()
==
"have error in stream"
{
return
nil
,
&
UpstreamFailoverError
{
return
nil
,
&
UpstreamFailoverError
{
...
@@ -2728,7 +3343,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2728,7 +3343,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs
=
streamResult
.
firstTokenMs
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
)
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
reqModel
,
toolNameMap
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -2745,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2745,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
},
nil
},
nil
}
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
// 确定目标URL
targetURL
:=
claudeAPIURL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -2759,11 +3374,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2759,11 +3374,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
// 失败时降级为透传原始headers
...
@@ -2794,7 +3414,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2794,7 +3414,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 白名单透传headers
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -2815,10 +3435,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2815,10 +3435,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
reqStream
)
}
// 处理anthropic-beta header(OAuth账号需要
特殊处理
)
// 处理
anthropic-beta header(OAuth
账号需要
包含 oauth beta
)
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
))
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -2828,6 +3468,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -2828,6 +3468,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
@@ -2897,30 +3546,117 @@ func defaultAPIKeyBetaHeader(body []byte) string {
...
@@ -2897,30 +3546,117 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return
claude
.
APIKeyBetaHeader
return
claude
.
APIKeyBetaHeader
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
func
applyClaudeOAuthHeaderDefaults
(
req
*
http
.
Request
,
isStream
bool
)
{
if
maxBytes
<=
0
{
if
req
==
nil
{
maxBytes
=
2048
return
}
}
if
len
(
b
)
>
maxBytes
{
if
req
.
Header
.
Get
(
"accept"
)
==
""
{
b
=
b
[
:
maxBytes
]
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
}
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
if
req
.
Header
.
Get
(
key
)
==
""
{
req
.
Header
.
Set
(
key
,
value
)
}
}
if
isStream
&&
req
.
Header
.
Get
(
"x-stainless-helper-method"
)
==
""
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
// 这类错误可以通过过滤thinking blocks并重试来解决
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
func
(
s
*
GatewayService
)
isThinkingBlockSignatureError
(
respBody
[]
byte
)
bool
{
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
if
msg
==
""
{
return
false
}
// Log for debugging
add
:=
func
(
v
string
)
{
log
.
Printf
(
"[SignatureCheck] Checking error message: %s"
,
msg
)
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
}
if
len
(
b
)
>
maxBytes
{
b
=
b
[
:
maxBytes
]
}
s
:=
string
(
b
)
// 保持一行,避免污染日志格式
s
=
strings
.
ReplaceAll
(
s
,
"
\n
"
,
"
\\
n"
)
s
=
strings
.
ReplaceAll
(
s
,
"
\r
"
,
"
\\
r"
)
return
s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func
(
s
*
GatewayService
)
isThinkingBlockSignatureError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
)))
if
msg
==
""
{
return
false
}
// Log for debugging
log
.
Printf
(
"[SignatureCheck] Checking error message: %s"
,
msg
)
// 检测signature相关的错误(更宽松的匹配)
// 检测signature相关的错误(更宽松的匹配)
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
...
@@ -3000,6 +3736,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
...
@@ -3000,6 +3736,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
@@ -3129,6 +3879,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
...
@@ -3129,6 +3879,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
@@ -3181,7 +3944,7 @@ type streamingResult struct {
...
@@ -3181,7 +3944,7 @@ type streamingResult struct {
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3276,6 +4039,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3276,6 +4039,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
pendingEventLines
:=
make
([]
string
,
0
,
4
)
var
toolInputBuffers
map
[
int
]
string
if
mimicClaudeCode
{
toolInputBuffers
=
make
(
map
[
int
]
string
)
}
transformToolInputJSON
:=
func
(
raw
string
)
string
{
if
!
mimicClaudeCode
{
return
raw
}
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
raw
}
var
parsed
any
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
parsed
);
err
!=
nil
{
return
replaceToolNamesInText
(
raw
,
toolNameMap
)
}
rewritten
,
changed
:=
rewriteParamKeysInValue
(
parsed
,
toolNameMap
)
if
changed
{
if
bytes
,
err
:=
json
.
Marshal
(
rewritten
);
err
==
nil
{
return
string
(
bytes
)
}
}
return
raw
}
processSSEEvent
:=
func
(
lines
[]
string
)
([]
string
,
string
,
error
)
{
if
len
(
lines
)
==
0
{
return
nil
,
""
,
nil
}
eventName
:=
""
dataLine
:=
""
for
_
,
line
:=
range
lines
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
{
eventName
=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
))
continue
}
if
dataLine
==
""
&&
sseDataRe
.
MatchString
(
trimmed
)
{
dataLine
=
sseDataRe
.
ReplaceAllString
(
trimmed
,
""
)
}
}
if
eventName
==
"error"
{
return
nil
,
dataLine
,
errors
.
New
(
"have error in stream"
)
}
if
dataLine
==
""
{
return
[]
string
{
strings
.
Join
(
lines
,
"
\n
"
)
+
"
\n\n
"
},
""
,
nil
}
if
dataLine
==
"[DONE]"
{
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
dataLine
+
"
\n\n
"
return
[]
string
{
block
},
dataLine
,
nil
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
dataLine
),
&
event
);
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
eventType
,
_
:=
event
[
"type"
]
.
(
string
)
if
eventName
==
""
{
eventName
=
eventType
}
if
needModelReplace
{
if
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
);
ok
{
if
model
,
ok
:=
msg
[
"model"
]
.
(
string
);
ok
&&
model
==
mappedModel
{
msg
[
"model"
]
=
originalModel
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_delta"
{
if
delta
,
ok
:=
event
[
"delta"
]
.
(
map
[
string
]
any
);
ok
{
if
deltaType
,
_
:=
delta
[
"type"
]
.
(
string
);
deltaType
==
"input_json_delta"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
partial
,
ok
:=
delta
[
"partial_json"
]
.
(
string
);
ok
{
toolInputBuffers
[
index
]
+=
partial
}
}
return
nil
,
dataLine
,
nil
}
}
}
if
mimicClaudeCode
&&
eventType
==
"content_block_stop"
{
if
indexVal
,
ok
:=
event
[
"index"
]
.
(
float64
);
ok
{
index
:=
int
(
indexVal
)
if
buffered
:=
toolInputBuffers
[
index
];
buffered
!=
""
{
delete
(
toolInputBuffers
,
index
)
transformed
:=
transformToolInputJSON
(
buffered
)
synthetic
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
index
,
"delta"
:
map
[
string
]
any
{
"type"
:
"input_json_delta"
,
"partial_json"
:
transformed
,
},
}
synthBytes
,
synthErr
:=
json
.
Marshal
(
synthetic
)
if
synthErr
==
nil
{
synthBlock
:=
"event: content_block_delta
\n
"
+
"data: "
+
string
(
synthBytes
)
+
"
\n\n
"
rewriteToolNamesInValue
(
event
,
toolNameMap
)
stopBytes
,
stopErr
:=
json
.
Marshal
(
event
)
if
stopErr
==
nil
{
stopBlock
:=
""
if
eventName
!=
""
{
stopBlock
=
"event: "
+
eventName
+
"
\n
"
}
stopBlock
+=
"data: "
+
string
(
stopBytes
)
+
"
\n\n
"
return
[]
string
{
synthBlock
,
stopBlock
},
string
(
stopBytes
),
nil
}
}
}
}
}
if
mimicClaudeCode
{
rewriteToolNamesInValue
(
event
,
toolNameMap
)
}
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
replaced
:=
dataLine
if
mimicClaudeCode
{
replaced
=
replaceToolNamesInText
(
dataLine
,
toolNameMap
)
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
replaced
+
"
\n\n
"
return
[]
string
{
block
},
replaced
,
nil
}
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
}
block
+=
"data: "
+
string
(
newData
)
+
"
\n\n
"
return
[]
string
{
block
},
string
(
newData
),
nil
}
for
{
for
{
select
{
select
{
case
ev
,
ok
:=
<-
events
:
case
ev
,
ok
:=
<-
events
:
...
@@ -3304,43 +4232,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3304,43 +4232,44 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
}
line
:=
ev
.
line
line
:=
ev
.
line
if
line
==
"event: error"
{
trimmed
:=
strings
.
TrimSpace
(
line
)
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if
clientDisconnected
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
return
nil
,
errors
.
New
(
"have error in stream"
)
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
trimmed
==
""
{
var
data
string
if
len
(
pendingEventLines
)
==
0
{
if
sseDataRe
.
MatchString
(
line
)
{
continue
data
=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
}
}
// 写入客户端(统一处理 data 行和非 data 行)
outputBlocks
,
data
,
err
:=
processSSEEvent
(
pendingEventLines
)
if
!
clientDisconnected
{
pendingEventLines
=
pendingEventLines
[
:
0
]
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
err
!=
nil
{
clientDisconnected
=
true
if
clientDisconnected
{
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
else
{
}
flusher
.
Flush
()
return
nil
,
err
}
}
}
// 无论客户端是否断开,都解析 usage(仅对 data 行)
for
_
,
block
:=
range
outputBlocks
{
if
data
!=
""
{
if
!
clientDisconnected
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
if
_
,
werr
:=
fmt
.
Fprint
(
w
,
block
);
werr
!=
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
clientDisconnected
=
true
firstTokenMs
=
&
ms
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
break
}
flusher
.
Flush
()
}
if
data
!=
""
{
if
firstTokenMs
==
nil
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
}
s
.
parseSSEUsage
(
data
,
usage
)
continue
}
}
pendingEventLines
=
append
(
pendingEventLines
,
line
)
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
@@ -3363,43 +4292,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -3363,43 +4292,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
rewriteParamKeysInValue
(
value
any
,
cache
map
[
string
]
string
)
(
any
,
bool
)
{
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
switch
v
:=
value
.
(
type
)
{
if
!
sseDataRe
.
MatchString
(
line
)
{
case
map
[
string
]
any
:
return
line
changed
:=
false
}
rewritten
:=
make
(
map
[
string
]
any
,
len
(
v
))
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
for
key
,
item
:=
range
v
{
if
data
==
""
||
data
==
"[DONE]"
{
newKey
:=
normalizeParamNameForOpenCode
(
key
,
cache
)
return
line
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
}
if
childChanged
{
changed
=
true
var
event
map
[
string
]
any
}
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
if
newKey
!=
key
{
return
line
changed
=
true
}
}
rewritten
[
newKey
]
=
newItem
// 只替换 message_start 事件中的 message.model
}
if
event
[
"type"
]
!=
"message_start"
{
if
!
changed
{
return
line
return
value
,
false
}
return
rewritten
,
true
case
[]
any
:
changed
:=
false
rewritten
:=
make
([]
any
,
len
(
v
))
for
idx
,
item
:=
range
v
{
newItem
,
childChanged
:=
rewriteParamKeysInValue
(
item
,
cache
)
if
childChanged
{
changed
=
true
}
rewritten
[
idx
]
=
newItem
}
if
!
changed
{
return
value
,
false
}
return
rewritten
,
true
default
:
return
value
,
false
}
}
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
any
)
func
rewriteToolNamesInValue
(
value
any
,
toolNameMap
map
[
string
]
string
)
bool
{
if
!
ok
{
switch
v
:=
value
.
(
type
)
{
return
line
case
map
[
string
]
any
:
changed
:=
false
if
blockType
,
_
:=
v
[
"type"
]
.
(
string
);
blockType
==
"tool_use"
{
if
name
,
ok
:=
v
[
"name"
]
.
(
string
);
ok
{
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
!=
name
{
v
[
"name"
]
=
mapped
changed
=
true
}
}
if
input
,
ok
:=
v
[
"input"
]
.
(
map
[
string
]
any
);
ok
{
rewrittenInput
,
inputChanged
:=
rewriteParamKeysInValue
(
input
,
toolNameMap
)
if
inputChanged
{
if
m
,
ok
:=
rewrittenInput
.
(
map
[
string
]
any
);
ok
{
v
[
"input"
]
=
m
changed
=
true
}
}
}
}
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
case
[]
any
:
changed
:=
false
for
_
,
item
:=
range
v
{
if
rewriteToolNamesInValue
(
item
,
toolNameMap
)
{
changed
=
true
}
}
return
changed
default
:
return
false
}
}
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
func
replaceToolNamesInText
(
text
string
,
toolNameMap
map
[
string
]
string
)
string
{
if
!
ok
||
model
!=
fromModel
{
if
text
==
""
{
return
line
return
text
}
}
output
:=
toolNameFieldRe
.
ReplaceAllStringFunc
(
text
,
func
(
match
string
)
string
{
submatches
:=
toolNameFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
name
:=
submatches
[
1
]
mapped
:=
normalizeToolNameForOpenCode
(
name
,
toolNameMap
)
if
mapped
==
name
{
return
match
}
return
strings
.
Replace
(
match
,
name
,
mapped
,
1
)
})
output
=
modelFieldRe
.
ReplaceAllStringFunc
(
output
,
func
(
match
string
)
string
{
submatches
:=
modelFieldRe
.
FindStringSubmatch
(
match
)
if
len
(
submatches
)
<
2
{
return
match
}
model
:=
submatches
[
1
]
mapped
:=
claude
.
DenormalizeModelID
(
model
)
if
mapped
==
model
{
return
match
}
return
strings
.
Replace
(
match
,
model
,
mapped
,
1
)
})
msg
[
"model"
]
=
toModel
for
mapped
,
original
:=
range
toolNameMap
{
newData
,
err
:=
json
.
Marshal
(
event
)
if
mapped
==
""
||
original
==
""
||
mapped
==
original
{
if
err
!=
nil
{
continue
return
line
}
output
=
strings
.
ReplaceAll
(
output
,
"
\"
"
+
mapped
+
"
\"
:"
,
"
\"
"
+
original
+
"
\"
:"
)
output
=
strings
.
ReplaceAll
(
output
,
"
\\\"
"
+
mapped
+
"
\\\"
:"
,
"
\\\"
"
+
original
+
"
\\\"
:"
)
}
}
return
"data: "
+
string
(
newData
)
return
output
}
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
...
@@ -3445,7 +4455,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
...
@@ -3445,7 +4455,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
,
mappedModel
string
,
toolNameMap
map
[
string
]
string
,
mimicClaudeCode
bool
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
...
@@ -3466,6 +4476,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
...
@@ -3466,6 +4476,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if
originalModel
!=
mappedModel
{
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
}
if
mimicClaudeCode
{
body
=
s
.
replaceToolNamesInResponseBody
(
body
,
toolNameMap
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
...
@@ -3503,6 +4516,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
...
@@ -3503,6 +4516,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
newBody
return
newBody
}
}
func
(
s
*
GatewayService
)
replaceToolNamesInResponseBody
(
body
[]
byte
,
toolNameMap
map
[
string
]
string
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
replaced
:=
replaceToolNamesInText
(
string
(
body
),
toolNameMap
)
if
replaced
==
string
(
body
)
{
return
body
}
return
[]
byte
(
replaced
)
}
if
!
rewriteToolNamesInValue
(
resp
,
toolNameMap
)
{
return
body
}
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
...
@@ -3657,6 +4692,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
...
@@ -3657,6 +4692,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return
nil
return
nil
}
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type
RecordUsageLongContextInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func
(
s
*
GatewayService
)
RecordUsageWithLongContext
(
ctx
context
.
Context
,
input
*
RecordUsageLongContextInput
)
error
{
result
:=
input
.
Result
apiKey
:=
input
.
APIKey
user
:=
input
.
User
account
:=
input
.
Account
subscription
:=
input
.
Subscription
// 获取费率倍数
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
}
var
cost
*
CostBreakdown
// 根据请求类型选择计费方式
if
result
.
ImageCount
>
0
{
// 图片生成计费
var
groupConfig
*
ImagePriceConfig
if
apiKey
.
Group
!=
nil
{
groupConfig
=
&
ImagePriceConfig
{
Price1K
:
apiKey
.
Group
.
ImagePrice1K
,
Price2K
:
apiKey
.
Group
.
ImagePrice2K
,
Price4K
:
apiKey
.
Group
.
ImagePrice4K
,
}
}
cost
=
s
.
billingService
.
CalculateImageCost
(
result
.
Model
,
result
.
ImageSize
,
result
.
ImageCount
,
groupConfig
,
multiplier
)
}
else
{
// Token 计费(使用长上下文计费方法)
tokens
:=
UsageTokens
{
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
}
var
err
error
cost
,
err
=
s
.
billingService
.
CalculateCostWithLongContext
(
result
.
Model
,
tokens
,
multiplier
,
input
.
LongContextThreshold
,
input
.
LongContextMultiplier
)
if
err
!=
nil
{
log
.
Printf
(
"Calculate cost failed: %v"
,
err
)
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling
:=
subscription
!=
nil
&&
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
billingType
:=
BillingTypeBalance
if
isSubscriptionBilling
{
billingType
=
BillingTypeSubscription
}
// 创建使用日志
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
var
imageSize
*
string
if
result
.
ImageSize
!=
""
{
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
imageSize
,
CreatedAt
:
time
.
Now
(),
}
// 添加 UserAgent
if
input
.
UserAgent
!=
""
{
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
}
if
subscription
!=
nil
{
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
err
!=
nil
{
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
// 根据计费类型执行扣费
if
isSubscriptionBilling
{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if
shouldBill
&&
cost
.
TotalCost
>
0
{
if
err
:=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
log
.
Printf
(
"Increment subscription usage failed: %v"
,
err
)
}
// 异步更新订阅缓存
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
}
}
else
{
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if
shouldBill
&&
cost
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Deduct balance failed: %v"
,
err
)
}
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
}
}
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
// 特点:不记录使用量、仅支持非流式响应
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
error
{
...
@@ -3668,6 +4859,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3668,6 +4859,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body
:=
parsed
.
Body
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
reqModel
:=
parsed
.
Model
isClaudeCode
:=
isClaudeCodeRequest
(
ctx
,
c
,
parsed
)
shouldMimicClaudeCode
:=
account
.
IsOAuth
()
&&
!
isClaudeCode
if
shouldMimicClaudeCode
{
normalizeOpts
:=
claudeOAuthNormalizeOptions
{
stripSystemCacheControl
:
true
}
body
,
reqModel
,
_
=
normalizeClaudeOAuthRequestBody
(
body
,
reqModel
,
normalizeOpts
)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if
account
.
Platform
==
PlatformAntigravity
{
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
0
})
...
@@ -3706,7 +4905,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3706,7 +4905,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// 构建上游请求
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
upstreamReq
,
err
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
s
.
countTokensError
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to build request"
)
return
err
return
err
...
@@ -3739,7 +4938,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3739,7 +4938,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
if
retryErr
==
nil
{
...
@@ -3804,7 +5003,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -3804,7 +5003,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
}
// buildCountTokensRequest 构建 count_tokens 上游请求
// buildCountTokensRequest 构建 count_tokens 上游请求
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
)
(
*
http
.
Request
,
error
)
{
func
(
s
*
GatewayService
)
buildCountTokensRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标 URL
// 确定目标 URL
targetURL
:=
claudeAPICountTokensURL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
...
@@ -3818,10 +5017,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3818,10 +5017,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
...
@@ -3845,7 +5049,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3845,7 +5049,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 白名单透传 headers
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
for
_
,
v
:=
range
values
{
...
@@ -3856,7 +5060,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3856,7 +5060,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
}
...
@@ -3869,10 +5073,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3869,10 +5073,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
}
if
tokenType
==
"oauth"
{
applyClaudeOAuthHeaderDefaults
(
req
,
false
)
}
// OAuth 账号:处理 anthropic-beta header
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
modelID
,
c
.
GetHeader
(
"anthropic-beta"
)))
if
mimicClaudeCode
{
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
}
else
{
beta
:=
s
.
getBetaHeader
(
modelID
,
clientBetaHeader
)
if
!
strings
.
Contains
(
beta
,
claude
.
BetaTokenCounting
)
{
beta
=
beta
+
","
+
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
beta
)
}
}
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
}
else
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
InjectBetaForAPIKey
&&
req
.
Header
.
Get
(
"anthropic-beta"
)
==
""
{
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if
requestNeedsBetaFeatures
(
body
)
{
if
requestNeedsBetaFeatures
(
body
)
{
...
@@ -3882,6 +5106,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -3882,6 +5106,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
return
req
,
nil
}
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
2220fd18
...
@@ -36,6 +36,11 @@ const (
...
@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay
=
16
*
time
.
Second
geminiRetryMaxDelay
=
16
*
time
.
Second
)
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const
geminiDummyThoughtSignature
=
"skip_thought_signature_validator"
type
GeminiMessagesCompatService
struct
{
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
groupRepo
GroupRepository
groupRepo
GroupRepository
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
}
geminiReq
=
ensureGeminiFunctionCallThoughtSignatures
(
geminiReq
)
originalClaudeBody
:=
body
originalClaudeBody
:=
body
proxyURL
:=
""
proxyURL
:=
""
...
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
}
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body
=
ensureGeminiFunctionCallThoughtSignatures
(
body
)
mappedModel
:=
originalModel
mappedModel
:=
originalModel
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
...
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
...
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
return
&
ts
return
&
ts
}
}
func
ensureGeminiFunctionCallThoughtSignatures
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when functionCall is present.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"functionCall"`
))
{
return
body
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
body
}
contentsAny
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contentsAny
)
==
0
{
return
body
}
modified
:=
false
for
_
,
c
:=
range
contentsAny
{
cm
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
partsAny
,
ok
:=
cm
[
"parts"
]
.
([]
any
)
if
!
ok
||
len
(
partsAny
)
==
0
{
continue
}
for
_
,
p
:=
range
partsAny
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
||
pm
==
nil
{
continue
}
if
fc
,
ok
:=
pm
[
"functionCall"
]
.
(
map
[
string
]
any
);
!
ok
||
fc
==
nil
{
continue
}
ts
,
_
:=
pm
[
"thoughtSignature"
]
.
(
string
)
if
strings
.
TrimSpace
(
ts
)
==
""
{
pm
[
"thoughtSignature"
]
=
geminiDummyThoughtSignature
modified
=
true
}
}
}
if
!
modified
{
return
body
}
b
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
body
}
return
b
}
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
...
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
...
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
toolUseIDToName
[
id
]
=
name
toolUseIDToName
[
id
]
=
name
}
}
signature
,
_
:=
bm
[
"signature"
]
.
(
string
)
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
signature
=
geminiDummyThoughtSignature
}
parts
=
append
(
parts
,
map
[
string
]
any
{
parts
=
append
(
parts
,
map
[
string
]
any
{
"thoughtSignature"
:
signature
,
"functionCall"
:
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
name
,
"name"
:
name
,
"args"
:
bm
[
"input"
],
"args"
:
bm
[
"input"
],
...
...
backend/internal/service/gemini_messages_compat_service_test.go
View file @
2220fd18
package
service
package
service
import
(
import
(
"encoding/json"
"strings"
"testing"
"testing"
)
)
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
})
}
}
}
}
func
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse
(
t
*
testing
.
T
)
{
claudeReq
:=
map
[
string
]
any
{
"model"
:
"claude-haiku-4-5-20251001"
,
"max_tokens"
:
10
,
"messages"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"hi"
},
},
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"ok"
},
map
[
string
]
any
{
"type"
:
"tool_use"
,
"id"
:
"toolu_123"
,
"name"
:
"default_api:write_file"
,
"input"
:
map
[
string
]
any
{
"path"
:
"a.txt"
,
"content"
:
"x"
},
// no signature on purpose
},
},
},
},
"tools"
:
[]
any
{
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"description"
:
"write file"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"path"
:
map
[
string
]
any
{
"type"
:
"string"
}},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
claudeReq
)
out
,
err
:=
convertClaudeMessagesToGeminiGenerateContent
(
b
)
if
err
!=
nil
{
t
.
Fatalf
(
"convert failed: %v"
,
err
)
}
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
functionCall
\"
"
)
{
t
.
Fatalf
(
"expected functionCall in output, got: %s"
,
s
)
}
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
func
TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing
(
t
*
testing
.
T
)
{
geminiReq
:=
map
[
string
]
any
{
"contents"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"parts"
:
[]
any
{
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"args"
:
map
[
string
]
any
{
"path"
:
"a.txt"
},
},
},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
geminiReq
)
out
:=
ensureGeminiFunctionCallThoughtSignatures
(
b
)
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
backend/internal/service/gemini_multiplatform_test.go
View file @
2220fd18
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGemini
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment