Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2220fd18
Commit
2220fd18
authored
Feb 03, 2026
by
song
Browse files
merge upstream main
parents
11ff73b5
df4c0adf
Changes
67
Expand all
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/account_test_service.go
View file @
2220fd18
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
...
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system"
:
[]
map
[
string
]
any
{
"system"
:
[]
map
[
string
]
any
{
{
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"You are C
laude
Code
, Anthropic's official CLI for Claude."
,
"text"
:
c
laudeCode
SystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
"type"
:
"ephemeral"
,
},
},
...
...
backend/internal/service/admin_service.go
View file @
2220fd18
...
@@ -115,6 +115,8 @@ type CreateGroupInput struct {
...
@@ -115,6 +115,8 @@ type CreateGroupInput struct {
MCPXMLInject
*
bool
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
SupportedModelScopes
[]
string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
}
type
UpdateGroupInput
struct
{
type
UpdateGroupInput
struct
{
...
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
...
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
MCPXMLInject
*
bool
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
SupportedModelScopes
*
[]
string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
}
type
CreateAccountInput
struct
{
type
CreateAccountInput
struct
{
...
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
mcpXMLInject
=
*
input
.
MCPXMLInject
mcpXMLInject
=
*
input
.
MCPXMLInject
}
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与新分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
var
err
error
accountIDsToCopy
,
err
=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
}
group
:=
&
Group
{
group
:=
&
Group
{
Name
:
input
.
Name
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Description
:
input
.
Description
,
...
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 如果有需要复制的账号,绑定到新分组
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
group
.
ID
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to new group: %w"
,
err
)
}
group
.
AccountCount
=
int64
(
len
(
accountIDsToCopy
))
}
return
group
,
nil
return
group
,
nil
}
}
...
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if
len
(
input
.
CopyAccountsFromGroupIDs
)
>
0
{
// 去重源分组 IDs
seen
:=
make
(
map
[
int64
]
struct
{})
uniqueSourceGroupIDs
:=
make
([]
int64
,
0
,
len
(
input
.
CopyAccountsFromGroupIDs
))
for
_
,
srcGroupID
:=
range
input
.
CopyAccountsFromGroupIDs
{
// 校验:源分组不能是自身
if
srcGroupID
==
id
{
return
nil
,
fmt
.
Errorf
(
"cannot copy accounts from self"
)
}
// 去重
if
_
,
exists
:=
seen
[
srcGroupID
];
!
exists
{
seen
[
srcGroupID
]
=
struct
{}{}
uniqueSourceGroupIDs
=
append
(
uniqueSourceGroupIDs
,
srcGroupID
)
}
}
// 校验源分组的平台是否与当前分组一致
for
_
,
srcGroupID
:=
range
uniqueSourceGroupIDs
{
srcGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
srcGroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"source group %d not found: %w"
,
srcGroupID
,
err
)
}
if
srcGroup
.
Platform
!=
group
.
Platform
{
return
nil
,
fmt
.
Errorf
(
"source group %d platform mismatch: expected %s, got %s"
,
srcGroupID
,
group
.
Platform
,
srcGroup
.
Platform
)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy
,
err
:=
s
.
groupRepo
.
GetAccountIDsByGroupIDs
(
ctx
,
uniqueSourceGroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts from source groups: %w"
,
err
)
}
// 先清空当前分组的所有账号绑定
if
_
,
err
:=
s
.
groupRepo
.
DeleteAccountGroupsByGroupID
(
ctx
,
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to clear existing account bindings: %w"
,
err
)
}
// 再绑定源分组的账号
if
len
(
accountIDsToCopy
)
>
0
{
if
err
:=
s
.
groupRepo
.
BindAccountsToGroup
(
ctx
,
id
,
accountIDsToCopy
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to bind accounts to group: %w"
,
err
)
}
}
}
if
s
.
authCacheInvalidator
!=
nil
{
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
}
...
...
backend/internal/service/admin_service_delete_test.go
View file @
2220fd18
...
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
...
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStub
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStub
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
proxyRepoStub
struct
{
type
proxyRepoStub
struct
{
deleteErr
error
deleteErr
error
countErr
error
countErr
error
...
...
backend/internal/service/admin_service_group_test.go
View file @
2220fd18
...
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
...
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStubForAdmin
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
repo
:=
&
groupRepoStubForAdmin
{}
...
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
...
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
}
func
(
s
*
groupRepoStubForFallbackCycle
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
groupRepoStubForInvalidRequestFallback
struct
{
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
groups
map
[
int64
]
*
Group
created
*
Group
created
*
Group
...
@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
...
@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
}
backend/internal/service/antigravity_gateway_service.go
View file @
2220fd18
...
@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
...
@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
}
}
// Antigravity 直接支持的模型(精确匹配透传)
// Antigravity 直接支持的模型(精确匹配透传)
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var
antigravitySupportedModels
=
map
[
string
]
bool
{
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"gemini-2.5-flash"
:
true
,
"gemini-2.5-flash-lite"
:
true
,
"gemini-2.5-flash-thinking"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
"gemini-3-pro-high"
:
true
,
...
@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
...
@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
var
antigravityPrefixMapping
=
[]
struct
{
var
antigravityPrefixMapping
=
[]
struct
{
prefix
string
prefix
string
target
string
target
string
}{
}{
// 长前缀优先
// gemini-2.5 → gemini-3 映射(长前缀优先)
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → 3-pro-image
{
"gemini-2.5-flash-thinking"
,
"gemini-3-flash"
},
// gemini-2.5-flash-thinking → gemini-3-flash
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → gemini-3-pro-image
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"gemini-2.5-flash-lite"
,
"gemini-3-flash"
},
// gemini-2.5-flash-lite → gemini-3-flash
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"gemini-2.5-flash"
,
"gemini-3-flash"
},
// gemini-2.5-flash → gemini-3-flash
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"gemini-2.5-pro-preview"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-preview → gemini-3-pro-high
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
{
"gemini-2.5-pro-exp"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro-exp → gemini-3-pro-high
{
"gemini-2.5-pro"
,
"gemini-3-pro-high"
},
// gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-3-haiku"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-haiku-xxx → sonnet
{
"claude-3-haiku"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-haiku-xxx → sonnet
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-haiku-4"
,
"claude-sonnet-4-5"
},
// → sonnet
{
"claude-haiku-4"
,
"claude-sonnet-4-5"
},
// → sonnet
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
}
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...
...
backend/internal/service/antigravity_gateway_service_test.go
View file @
2220fd18
...
@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
...
@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return
s
.
resp
,
s
.
err
return
s
.
resp
,
s
.
err
}
}
func
(
s
*
httpUpstreamStub
)
DoWithTLS
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
,
_
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
writer
:=
httptest
.
NewRecorder
()
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
2220fd18
...
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
...
@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected
:
"claude-sonnet-4-5"
,
expected
:
"claude-sonnet-4-5"
,
},
},
// 3. Gemini
透传
// 3. Gemini
2.5 → 3 映射
{
{
name
:
"Gemini
透传
- gemini-2.5-flash"
,
name
:
"Gemini
映射
- gemini-2.5-flash
→ gemini-3-flash
"
,
requestedModel
:
"gemini-2.5-flash"
,
requestedModel
:
"gemini-2.5-flash"
,
accountMapping
:
nil
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-flash"
,
expected
:
"gemini-
3
-flash"
,
},
},
{
{
name
:
"Gemini
透传
- gemini-2.5-pro"
,
name
:
"Gemini
映射
- gemini-2.5-pro
→ gemini-3-pro-high
"
,
requestedModel
:
"gemini-2.5-pro"
,
requestedModel
:
"gemini-2.5-pro"
,
accountMapping
:
nil
,
accountMapping
:
nil
,
expected
:
"gemini-
2.5
-pro"
,
expected
:
"gemini-
3
-pro
-high
"
,
},
},
{
{
name
:
"Gemini透传 - gemini-future-model"
,
name
:
"Gemini透传 - gemini-future-model"
,
...
...
backend/internal/service/auth_service.go
View file @
2220fd18
...
@@ -19,17 +19,19 @@ import (
...
@@ -19,17 +19,19 @@ import (
)
)
var
(
var
(
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrTokenRevoked
=
infraerrors
.
Unauthorized
(
"TOKEN_REVOKED"
,
"token has been revoked"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrEmailVerifyRequired
=
infraerrors
.
BadRequest
(
"EMAIL_VERIFY_REQUIRED"
,
"email verification is required"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrRegDisabled
=
infraerrors
.
Forbidden
(
"REGISTRATION_DISABLED"
,
"registration is currently disabled"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"service temporarily unavailable"
)
ErrInvitationCodeRequired
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_REQUIRED"
,
"invitation code is required"
)
ErrInvitationCodeInvalid
=
infraerrors
.
BadRequest
(
"INVITATION_CODE_INVALID"
,
"invalid or used invitation code"
)
)
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
...
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务
// AuthService 认证服务
type
AuthService
struct
{
type
AuthService
struct
{
userRepo
UserRepository
userRepo
UserRepository
redeemRepo
RedeemCodeRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
settingService
*
SettingService
settingService
*
SettingService
emailService
*
EmailService
emailService
*
EmailService
...
@@ -58,6 +61,7 @@ type AuthService struct {
...
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
// NewAuthService 创建认证服务实例
func
NewAuthService
(
func
NewAuthService
(
userRepo
UserRepository
,
userRepo
UserRepository
,
redeemRepo
RedeemCodeRepository
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
settingService
*
SettingService
,
settingService
*
SettingService
,
emailService
*
EmailService
,
emailService
*
EmailService
,
...
@@ -67,6 +71,7 @@ func NewAuthService(
...
@@ -67,6 +71,7 @@ func NewAuthService(
)
*
AuthService
{
)
*
AuthService
{
return
&
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
userRepo
:
userRepo
,
redeemRepo
:
redeemRepo
,
cfg
:
cfg
,
cfg
:
cfg
,
settingService
:
settingService
,
settingService
:
settingService
,
emailService
:
emailService
,
emailService
:
emailService
,
...
@@ -78,11 +83,11 @@ func NewAuthService(
...
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
)
}
}
// RegisterWithVerification 用户注册(支持邮件验证
和
优惠码),返回token和用户
// RegisterWithVerification 用户注册(支持邮件验证
、
优惠码
和邀请码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
return
""
,
nil
,
ErrRegDisabled
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrEmailReserved
return
""
,
nil
,
ErrEmailReserved
}
}
// 检查是否需要邀请码
var
invitationRedeemCode
*
RedeemCode
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsInvitationCodeEnabled
(
ctx
)
{
if
invitationCode
==
""
{
return
""
,
nil
,
ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
// 检查类型和状态
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
log
.
Printf
(
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
return
""
,
nil
,
ErrInvitationCodeInvalid
}
invitationRedeemCode
=
redeemCode
}
// 检查是否需要邮件验证
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
...
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
// 邀请码标记失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
backend/internal/service/auth_service_register_test.go
View file @
2220fd18
...
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
...
@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return
NewAuthService
(
return
NewAuthService
(
repo
,
repo
,
nil
,
// redeemRepo
cfg
,
cfg
,
settingService
,
settingService
,
emailService
,
emailService
,
...
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
...
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
}
...
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
...
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
}
...
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
...
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
}
...
...
backend/internal/service/billing_service.go
View file @
2220fd18
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
...
@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
}
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func
(
s
*
BillingService
)
CalculateCostWithLongContext
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
threshold
int
,
extraMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
// 未启用长上下文计费,直接走正常计费
if
threshold
<=
0
||
extraMultiplier
<=
1
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 计算总输入 token(缓存读取 + 新输入)
total
:=
tokens
.
CacheReadTokens
+
tokens
.
InputTokens
if
total
<=
threshold
{
return
s
.
CalculateCost
(
model
,
tokens
,
rateMultiplier
)
}
// 拆分成范围内和范围外
var
inRangeCacheTokens
,
inRangeInputTokens
int
var
outRangeCacheTokens
,
outRangeInputTokens
int
if
tokens
.
CacheReadTokens
>=
threshold
{
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens
=
threshold
inRangeInputTokens
=
0
outRangeCacheTokens
=
tokens
.
CacheReadTokens
-
threshold
outRangeInputTokens
=
tokens
.
InputTokens
}
else
{
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens
=
tokens
.
CacheReadTokens
inRangeInputTokens
=
threshold
-
tokens
.
CacheReadTokens
outRangeCacheTokens
=
0
outRangeInputTokens
=
tokens
.
InputTokens
-
inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens
:=
UsageTokens
{
InputTokens
:
inRangeInputTokens
,
OutputTokens
:
tokens
.
OutputTokens
,
// 输出只算一次
CacheCreationTokens
:
tokens
.
CacheCreationTokens
,
CacheReadTokens
:
inRangeCacheTokens
,
}
inRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
inRangeTokens
,
rateMultiplier
)
if
err
!=
nil
{
return
nil
,
err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens
:=
UsageTokens
{
InputTokens
:
outRangeInputTokens
,
CacheReadTokens
:
outRangeCacheTokens
,
}
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
if
err
!=
nil
{
return
inRangeCost
,
nil
// 出错时返回范围内成本
}
// 合并成本
return
&
CostBreakdown
{
InputCost
:
inRangeCost
.
InputCost
+
outRangeCost
.
InputCost
,
OutputCost
:
inRangeCost
.
OutputCost
,
CacheCreationCost
:
inRangeCost
.
CacheCreationCost
,
CacheReadCost
:
inRangeCost
.
CacheReadCost
+
outRangeCost
.
CacheReadCost
,
TotalCost
:
inRangeCost
.
TotalCost
+
outRangeCost
.
TotalCost
,
ActualCost
:
inRangeCost
.
ActualCost
+
outRangeCost
.
ActualCost
,
},
nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
models
:=
make
([]
string
,
0
)
models
:=
make
([]
string
,
0
)
...
...
backend/internal/service/domain_constants.go
View file @
2220fd18
...
@@ -39,6 +39,7 @@ const (
...
@@ -39,6 +39,7 @@ const (
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeBalance
=
domain
.
RedeemTypeBalance
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeConcurrency
=
domain
.
RedeemTypeConcurrency
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeSubscription
=
domain
.
RedeemTypeSubscription
RedeemTypeInvitation
=
domain
.
RedeemTypeInvitation
)
)
// PromoCode status constants
// PromoCode status constants
...
@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
...
@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
// Setting keys
const
(
const
(
// 注册设置
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled
=
"invitation_code_enabled"
// 是否启用邀请码注册
// 邮件服务设置
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
...
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
2220fd18
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
...
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGateway
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
return
&
v
}
}
...
...
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
2220fd18
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
tests
:=
[]
struct
{
name
string
name
string
body
string
body
string
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
},
{
{
name
:
"string system equals Claude Code prompt"
,
name
:
"string system equals Claude Code prompt"
,
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
},
{
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
},
{
{
name
:
"empty array"
,
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
2220fd18
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
2220fd18
This diff is collapsed.
Click to expand it.
backend/internal/service/gemini_messages_compat_service.go
View file @
2220fd18
...
@@ -36,6 +36,11 @@ const (
...
@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay
=
16
*
time
.
Second
geminiRetryMaxDelay
=
16
*
time
.
Second
)
)
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const
geminiDummyThoughtSignature
=
"skip_thought_signature_validator"
type
GeminiMessagesCompatService
struct
{
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
groupRepo
GroupRepository
groupRepo
GroupRepository
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
}
geminiReq
=
ensureGeminiFunctionCallThoughtSignatures
(
geminiReq
)
originalClaudeBody
:=
body
originalClaudeBody
:=
body
proxyURL
:=
""
proxyURL
:=
""
...
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
}
}
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body
=
ensureGeminiFunctionCallThoughtSignatures
(
body
)
mappedModel
:=
originalModel
mappedModel
:=
originalModel
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
mappedModel
=
account
.
GetMappedModel
(
originalModel
)
...
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
...
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
return
&
ts
return
&
ts
}
}
func
ensureGeminiFunctionCallThoughtSignatures
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when functionCall is present.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"functionCall"`
))
{
return
body
}
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
body
}
contentsAny
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contentsAny
)
==
0
{
return
body
}
modified
:=
false
for
_
,
c
:=
range
contentsAny
{
cm
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
partsAny
,
ok
:=
cm
[
"parts"
]
.
([]
any
)
if
!
ok
||
len
(
partsAny
)
==
0
{
continue
}
for
_
,
p
:=
range
partsAny
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
||
pm
==
nil
{
continue
}
if
fc
,
ok
:=
pm
[
"functionCall"
]
.
(
map
[
string
]
any
);
!
ok
||
fc
==
nil
{
continue
}
ts
,
_
:=
pm
[
"thoughtSignature"
]
.
(
string
)
if
strings
.
TrimSpace
(
ts
)
==
""
{
pm
[
"thoughtSignature"
]
=
geminiDummyThoughtSignature
modified
=
true
}
}
}
if
!
modified
{
return
body
}
b
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
body
}
return
b
}
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
func
extractGeminiFinishReason
(
geminiResp
map
[
string
]
any
)
string
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
candidates
,
ok
:=
geminiResp
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
cand
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
...
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
...
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
if
strings
.
TrimSpace
(
id
)
!=
""
&&
strings
.
TrimSpace
(
name
)
!=
""
{
toolUseIDToName
[
id
]
=
name
toolUseIDToName
[
id
]
=
name
}
}
signature
,
_
:=
bm
[
"signature"
]
.
(
string
)
signature
=
strings
.
TrimSpace
(
signature
)
if
signature
==
""
{
signature
=
geminiDummyThoughtSignature
}
parts
=
append
(
parts
,
map
[
string
]
any
{
parts
=
append
(
parts
,
map
[
string
]
any
{
"thoughtSignature"
:
signature
,
"functionCall"
:
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
name
,
"name"
:
name
,
"args"
:
bm
[
"input"
],
"args"
:
bm
[
"input"
],
...
...
backend/internal/service/gemini_messages_compat_service_test.go
View file @
2220fd18
package
service
package
service
import
(
import
(
"encoding/json"
"strings"
"testing"
"testing"
)
)
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
...
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
})
})
}
}
}
}
func
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse
(
t
*
testing
.
T
)
{
claudeReq
:=
map
[
string
]
any
{
"model"
:
"claude-haiku-4-5-20251001"
,
"max_tokens"
:
10
,
"messages"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"hi"
},
},
},
map
[
string
]
any
{
"role"
:
"assistant"
,
"content"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"ok"
},
map
[
string
]
any
{
"type"
:
"tool_use"
,
"id"
:
"toolu_123"
,
"name"
:
"default_api:write_file"
,
"input"
:
map
[
string
]
any
{
"path"
:
"a.txt"
,
"content"
:
"x"
},
// no signature on purpose
},
},
},
},
"tools"
:
[]
any
{
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"description"
:
"write file"
,
"input_schema"
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"path"
:
map
[
string
]
any
{
"type"
:
"string"
}},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
claudeReq
)
out
,
err
:=
convertClaudeMessagesToGeminiGenerateContent
(
b
)
if
err
!=
nil
{
t
.
Fatalf
(
"convert failed: %v"
,
err
)
}
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
functionCall
\"
"
)
{
t
.
Fatalf
(
"expected functionCall in output, got: %s"
,
s
)
}
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
func
TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing
(
t
*
testing
.
T
)
{
geminiReq
:=
map
[
string
]
any
{
"contents"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"parts"
:
[]
any
{
map
[
string
]
any
{
"functionCall"
:
map
[
string
]
any
{
"name"
:
"default_api:write_file"
,
"args"
:
map
[
string
]
any
{
"path"
:
"a.txt"
},
},
},
},
},
},
}
b
,
_
:=
json
.
Marshal
(
geminiReq
)
out
:=
ensureGeminiFunctionCallThoughtSignatures
(
b
)
s
:=
string
(
out
)
if
!
strings
.
Contains
(
s
,
"
\"
thoughtSignature
\"
:
\"
"
+
geminiDummyThoughtSignature
+
"
\"
"
)
{
t
.
Fatalf
(
"expected injected thoughtSignature %q, got: %s"
,
geminiDummyThoughtSignature
,
s
)
}
}
backend/internal/service/gemini_multiplatform_test.go
View file @
2220fd18
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
...
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockGroupRepoForGemini
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment