Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
de7ff902
Commit
de7ff902
authored
Feb 04, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test
parents
317f26f0
dd96ada3
Changes
90
Hide whitespace changes
Inline
Side-by-side
backend/internal/server/middleware/api_key_auth_test.go
View file @
de7ff902
...
@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
...
@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return
nil
,
errors
.
New
(
"not implemented"
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
}
func
(
r
*
stubApiKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
backend/internal/service/admin_service.go
View file @
de7ff902
...
@@ -116,9 +116,14 @@ type CreateGroupInput struct {
...
@@ -116,9 +116,14 @@ type CreateGroupInput struct {
SoraVideoPricePerRequestHD
*
float64
SoraVideoPricePerRequestHD
*
float64
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
ModelRoutingEnabled
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
CopyAccountsFromGroupIDs
[]
int64
}
}
...
@@ -145,9 +150,14 @@ type UpdateGroupInput struct {
...
@@ -145,9 +150,14 @@ type UpdateGroupInput struct {
SoraVideoPricePerRequestHD
*
float64
SoraVideoPricePerRequestHD
*
float64
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
ModelRoutingEnabled
*
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
CopyAccountsFromGroupIDs
[]
int64
}
}
...
@@ -611,6 +621,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -611,6 +621,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
nil
,
err
return
nil
,
err
}
}
}
}
fallbackOnInvalidRequest
:=
input
.
FallbackGroupIDOnInvalidRequest
if
fallbackOnInvalidRequest
!=
nil
&&
*
fallbackOnInvalidRequest
<=
0
{
fallbackOnInvalidRequest
=
nil
}
// 校验无效请求兜底分组
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
0
,
platform
,
subscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
mcpXMLInject
:=
true
if
input
.
MCPXMLInject
!=
nil
{
mcpXMLInject
=
*
input
.
MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
var
accountIDsToCopy
[]
int64
...
@@ -645,26 +671,29 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -645,26 +671,29 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
}
group
:=
&
Group
{
group
:=
&
Group
{
Name
:
input
.
Name
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Description
:
input
.
Description
,
Platform
:
platform
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
dailyLimit
,
DailyLimitUSD
:
dailyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice4K
:
imagePrice4K
,
ImagePrice4K
:
imagePrice4K
,
SoraImagePrice360
:
soraImagePrice360
,
SoraImagePrice360
:
soraImagePrice360
,
SoraImagePrice540
:
soraImagePrice540
,
SoraImagePrice540
:
soraImagePrice540
,
SoraVideoPricePerRequest
:
soraVideoPrice
,
SoraVideoPricePerRequest
:
soraVideoPrice
,
SoraVideoPricePerRequestHD
:
soraVideoPriceHD
,
SoraVideoPricePerRequestHD
:
soraVideoPriceHD
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
FallbackGroupID
:
input
.
FallbackGroupID
,
ModelRouting
:
input
.
ModelRouting
,
FallbackGroupIDOnInvalidRequest
:
fallbackOnInvalidRequest
,
ModelRouting
:
input
.
ModelRouting
,
MCPXMLInject
:
mcpXMLInject
,
SupportedModelScopes
:
input
.
SupportedModelScopes
,
}
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -735,6 +764,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
...
@@ -735,6 +764,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
}
}
}
}
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func
(
s
*
adminServiceImpl
)
validateFallbackGroupOnInvalidRequest
(
ctx
context
.
Context
,
currentGroupID
int64
,
platform
,
subscriptionType
string
,
fallbackGroupID
int64
)
error
{
if
platform
!=
PlatformAnthropic
&&
platform
!=
PlatformAntigravity
{
return
fmt
.
Errorf
(
"invalid request fallback only supported for anthropic or antigravity groups"
)
}
if
subscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"subscription groups cannot set invalid request fallback"
)
}
if
currentGroupID
>
0
&&
currentGroupID
==
fallbackGroupID
{
return
fmt
.
Errorf
(
"cannot set self as invalid request fallback group"
)
}
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
fallbackGroupID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
if
fallbackGroup
.
Platform
!=
PlatformAnthropic
{
return
fmt
.
Errorf
(
"fallback group must be anthropic platform"
)
}
if
fallbackGroup
.
SubscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"fallback group cannot be subscription type"
)
}
if
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
return
fmt
.
Errorf
(
"fallback group cannot have invalid request fallback configured"
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -813,6 +873,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -813,6 +873,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group
.
FallbackGroupID
=
nil
group
.
FallbackGroupID
=
nil
}
}
}
}
fallbackOnInvalidRequest
:=
group
.
FallbackGroupIDOnInvalidRequest
if
input
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
if
*
input
.
FallbackGroupIDOnInvalidRequest
>
0
{
fallbackOnInvalidRequest
=
input
.
FallbackGroupIDOnInvalidRequest
}
else
{
fallbackOnInvalidRequest
=
nil
}
}
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
id
,
group
.
Platform
,
group
.
SubscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
group
.
FallbackGroupIDOnInvalidRequest
=
fallbackOnInvalidRequest
// 模型路由配置
// 模型路由配置
if
input
.
ModelRouting
!=
nil
{
if
input
.
ModelRouting
!=
nil
{
...
@@ -821,6 +895,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -821,6 +895,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
ModelRoutingEnabled
!=
nil
{
if
input
.
ModelRoutingEnabled
!=
nil
{
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
}
}
if
input
.
MCPXMLInject
!=
nil
{
group
.
MCPXMLInject
=
*
input
.
MCPXMLInject
}
// 支持的模型系列(仅 antigravity 平台使用)
if
input
.
SupportedModelScopes
!=
nil
{
group
.
SupportedModelScopes
=
*
input
.
SupportedModelScopes
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
...
backend/internal/service/admin_service_group_test.go
View file @
de7ff902
...
@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
...
@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
}
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
created
*
Group
updated
*
Group
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Create
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
created
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Update
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
updated
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
fallback
*
Group
wantMessage
string
}{
{
name
:
"openai_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"antigravity_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"subscription_group"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
wantMessage
:
"fallback group cannot be subscription type"
,
},
{
name
:
"nested_fallback"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
func
()
*
int64
{
v
:=
int64
(
99
);
return
&
v
}(),
},
wantMessage
:
"fallback group cannot have invalid request fallback configured"
,
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
fallbackID
:=
tc
.
fallback
.
ID
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
tc
.
fallback
,
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
tc
.
wantMessage
)
require
.
Nil
(
t
,
repo
.
created
)
})
}
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackNotFound
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group not found"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
zero
:=
int64
(
0
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
zero
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Nil
(
t
,
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
SubscriptionType
:
SubscriptionTypeSubscription
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
clear
:=
int64
(
0
)
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
FallbackGroupIDOnInvalidRequest
:
&
clear
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Nil
(
t
,
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cannot be subscription type"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
backend/internal/service/antigravity_gateway_service.go
View file @
de7ff902
...
@@ -13,23 +13,34 @@ import (
...
@@ -13,23 +13,34 @@ import (
"net"
"net"
"net/http"
"net/http"
"os"
"os"
"strconv"
"strings"
"strings"
"sync/atomic"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
)
)
const
(
const
(
antigravityStickySessionTTL
=
time
.
Hour
antigravityStickySessionTTL
=
time
.
Hour
antigravityMaxRetries
=
3
antigravity
Default
MaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
)
const
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
const
(
antigravityMaxRetriesEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv
=
"GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv
=
"GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv
=
"GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数
// antigravityRetryLoopParams 重试循环的参数
type
antigravityRetryLoopParams
struct
{
type
antigravityRetryLoopParams
struct
{
...
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
...
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
action
string
action
string
body
[]
byte
body
[]
byte
quotaScope
AntigravityQuotaScope
quotaScope
AntigravityQuotaScope
maxRetries
int
c
*
gin
.
Context
c
*
gin
.
Context
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
settingService
*
SettingService
settingService
*
SettingService
...
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
...
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
resp
*
http
.
Response
resp
*
http
.
Response
}
}
// PromptTooLongError 表示上游明确返回 prompt too long
type
PromptTooLongError
struct
{
StatusCode
int
RequestID
string
Body
[]
byte
}
func
(
e
*
PromptTooLongError
)
Error
()
string
{
return
fmt
.
Sprintf
(
"prompt too long: status=%d"
,
e
.
StatusCode
)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func
antigravityRetryLoop
(
p
antigravityRetryLoopParams
)
(
*
antigravityRetryLoopResult
,
error
)
{
func
antigravityRetryLoop
(
p
antigravityRetryLoopParams
)
(
*
antigravityRetryLoopResult
,
error
)
{
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
baseURLs
:=
antigravity
.
ForwardBaseURLs
()
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLsWithBase
(
baseURLs
)
if
len
(
availableURLs
)
==
0
{
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
availableURLs
=
baseURLs
}
maxRetries
:=
p
.
maxRetries
if
maxRetries
<=
0
{
maxRetries
=
antigravityDefaultMaxRetries
}
}
var
resp
*
http
.
Response
var
resp
*
http
.
Response
...
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
...
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop
:
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
for
urlIdx
,
baseURL
:=
range
availableURLs
{
usedBaseURL
=
baseURL
usedBaseURL
=
baseURL
for
attempt
:=
1
;
attempt
<=
antigravityM
axRetries
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
m
axRetries
;
attempt
++
{
select
{
select
{
case
<-
p
.
ctx
.
Done
()
:
case
<-
p
.
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
p
.
prefix
,
p
.
ctx
.
Err
())
log
.
Printf
(
"%s status=context_canceled error=%v"
,
p
.
prefix
,
p
.
ctx
.
Err
())
...
@@ -109,8 +138,8 @@ urlFallbackLoop:
...
@@ -109,8 +138,8 @@ urlFallbackLoop:
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
urlFallbackLoop
continue
urlFallbackLoop
}
}
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
err
)
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
m
axRetries
,
err
)
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
return
nil
,
p
.
ctx
.
Err
()
...
@@ -134,7 +163,7 @@ urlFallbackLoop:
...
@@ -134,7 +163,7 @@ urlFallbackLoop:
}
}
// 账户/模型配额限流,重试 3 次(指数退避)
// 账户/模型配额限流,重试 3 次(指数退避)
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
@@ -147,7 +176,7 @@ urlFallbackLoop:
...
@@ -147,7 +176,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
Detail
:
getUpstreamDetail
(
respBody
),
})
})
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
200
))
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
200
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
return
nil
,
p
.
ctx
.
Err
()
...
@@ -171,7 +200,7 @@ urlFallbackLoop:
...
@@ -171,7 +200,7 @@ urlFallbackLoop:
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
@@ -184,7 +213,7 @@ urlFallbackLoop:
...
@@ -184,7 +213,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
Detail
:
getUpstreamDetail
(
respBody
),
})
})
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
500
))
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
500
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
return
nil
,
p
.
ctx
.
Err
()
...
@@ -390,6 +419,11 @@ type TestConnectionResult struct {
...
@@ -390,6 +419,11 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func
(
s
*
AntigravityGatewayService
)
TestConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
TestConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
// 上游透传账号使用专用测试方法
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
testUpstreamConnection
(
ctx
,
account
,
modelID
)
}
// 获取 token
// 获取 token
if
s
.
tokenProvider
==
nil
{
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
...
@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
...
@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return
nil
,
lastErr
return
nil
,
lastErr
}
}
// testUpstreamConnection 测试上游透传账号连接
func
(
s
*
AntigravityGatewayService
)
testUpstreamConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
errors
.
New
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 使用 Claude 模型进行测试
if
modelID
==
""
{
modelID
=
"claude-sonnet-4-20250514"
}
// 构建最小测试请求
testReq
:=
map
[
string
]
any
{
"model"
:
modelID
,
"max_tokens"
:
1
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"."
},
},
}
requestBody
,
err
:=
json
.
Marshal
(
testReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"构建请求失败: %w"
,
err
)
}
// 构建 HTTP 请求
upstreamURL
:=
baseURL
+
"/v1/messages"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
requestBody
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
log
.
Printf
(
"[antigravity-Test-Upstream] account=%s url=%s"
,
account
.
Name
,
upstreamURL
)
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
respBody
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
if
resp
.
StatusCode
>=
400
{
return
nil
,
fmt
.
Errorf
(
"API 返回 %d: %s"
,
resp
.
StatusCode
,
string
(
respBody
))
}
// 提取响应文本
var
respData
map
[
string
]
any
text
:=
""
if
json
.
Unmarshal
(
respBody
,
&
respData
)
==
nil
{
if
content
,
ok
:=
respData
[
"content"
]
.
([]
any
);
ok
&&
len
(
content
)
>
0
{
if
block
,
ok
:=
content
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
t
,
ok
:=
block
[
"text"
]
.
(
string
);
ok
{
text
=
t
}
}
}
}
return
&
TestConnectionResult
{
Text
:
text
,
MappedModel
:
modelID
,
},
nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func
(
s
*
AntigravityGatewayService
)
buildGeminiTestRequest
(
projectID
,
model
string
)
([]
byte
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
buildGeminiTestRequest
(
projectID
,
model
string
)
([]
byte
,
error
)
{
...
@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
...
@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
}
}
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
group
!=
nil
{
opts
.
EnableMCPXML
=
group
.
MCPXMLInject
}
return
opts
return
opts
}
}
...
@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
...
@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
// 上游透传账号直接转发,不走 OAuth token 刷新
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
ForwardUpstream
(
ctx
,
c
,
account
,
body
)
}
startTime
:=
time
.
Now
()
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
...
@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel
:=
claudeReq
.
Model
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
if
s
.
tokenProvider
==
nil
{
...
@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
...
@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
})
if
retryErr
!=
nil
{
if
retryErr
!=
nil
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
...
@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试)
// 处理错误响应(重试后仍失败或不触发重试)
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
==
http
.
StatusBadRequest
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
log
.
Printf
(
"%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s"
,
prefix
,
isPromptTooLongError
(
respBody
),
upstreamMsg
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateForLog
(
respBody
,
500
))
}
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isPromptTooLongError
(
respBody
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"prompt_too_long"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
PromptTooLongError
{
StatusCode
:
resp
.
StatusCode
,
RequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Body
:
respBody
,
}
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
...
@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
Model
:
original
Model
,
//
使用原始模型用于计费和日志
Model
:
billing
Model
,
//
计费模型(可按映射模型覆盖)
Stream
:
claudeReq
.
Stream
,
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
...
@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
...
@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
return
true
return
true
}
}
// Detect thinking block modification errors:
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
return
true
}
return
false
return
false
}
}
func
isPromptTooLongError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
)))
if
msg
==
""
{
msg
=
strings
.
ToLower
(
string
(
respBody
))
}
return
strings
.
Contains
(
msg
,
"prompt is too long"
)
}
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
var
payload
map
[
string
]
any
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
""
return
""
}
}
parseNestedMessage
:=
func
(
msg
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
msg
)
if
trimmed
==
""
||
!
strings
.
HasPrefix
(
trimmed
,
"{"
)
{
return
""
}
var
nested
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
nested
);
err
!=
nil
{
return
""
}
if
errObj
,
ok
:=
nested
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
innerMsg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
}
if
innerMsg
,
ok
:=
nested
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
return
""
}
// Google-style: {"error": {"message": "..."}}
// Google-style: {"error": {"message": "..."}}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
return
msg
}
}
}
}
// Fallback: top-level message
// Fallback: top-level message
if
msg
,
ok
:=
payload
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
msg
,
ok
:=
payload
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
return
msg
}
}
...
@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
...
@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return
changed
,
nil
return
changed
,
nil
}
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func
(
s
*
AntigravityGatewayService
)
ForwardUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
// 获取上游配置
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 解析请求获取模型信息
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
// 创建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
// 设置请求头
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
// Claude API 兼容
// 透传 Claude 相关 headers
if
v
:=
c
.
GetHeader
(
"anthropic-version"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
v
)
}
if
v
:=
c
.
GetHeader
(
"anthropic-beta"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
v
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
log
.
Printf
(
"%s upstream request failed: %v"
,
prefix
,
err
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 429 错误时标记账号限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
AntigravityQuotaScopeClaude
)
}
// 透传上游错误
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
resp
.
StatusCode
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billingModel
,
},
nil
}
// 处理成功响应(流式/非流式)
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
claudeReq
.
Stream
{
// 流式响应:透传
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
usage
,
firstTokenMs
=
s
.
streamUpstreamResponse
(
c
,
resp
,
startTime
)
}
else
{
// 非流式响应:直接透传
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read upstream response: %w"
,
err
)
}
// 提取 usage
usage
=
s
.
extractClaudeUsage
(
respBody
)
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
http
.
StatusOK
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
}
// 构建计费结果
duration
:=
time
.
Since
(
startTime
)
log
.
Printf
(
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billingModel
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{
InputTokens
:
usage
.
InputTokens
,
OutputTokens
:
usage
.
OutputTokens
,
CacheReadInputTokens
:
usage
.
CacheReadInputTokens
,
CacheCreationInputTokens
:
usage
.
CacheCreationInputTokens
,
},
},
nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
ClaudeUsage
,
*
int
)
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenRecorded
bool
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
buf
:=
make
([]
byte
,
0
,
64
*
1024
)
scanner
.
Buffer
(
buf
,
1024
*
1024
)
for
scanner
.
Scan
()
{
line
:=
scanner
.
Bytes
()
// 记录首 token 时间
if
!
firstTokenRecorded
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenRecorded
=
true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if
bytes
.
HasPrefix
(
line
,
[]
byte
(
"data: "
))
{
dataStr
:=
bytes
.
TrimPrefix
(
line
,
[]
byte
(
"data: "
))
var
event
map
[
string
]
any
if
json
.
Unmarshal
(
dataStr
,
&
event
)
==
nil
{
if
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
}
}
// 透传行
_
,
_
=
c
.
Writer
.
Write
(
line
)
_
,
_
=
c
.
Writer
.
Write
([]
byte
(
"
\n
"
))
c
.
Writer
.
Flush
()
}
return
usage
,
firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
// ForwardGemini 转发 Gemini 协议请求
// ForwardGemini 转发 Gemini 协议请求
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
startTime
:=
time
.
Now
()
...
@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
...
@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
}
mappedModel
:=
s
.
getMappedModel
(
account
,
originalModel
)
mappedModel
:=
s
.
getMappedModel
(
account
,
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
if
s
.
tokenProvider
==
nil
{
...
@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
...
@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL
=
account
.
Proxy
.
URL
()
proxyURL
=
account
.
Proxy
.
URL
()
}
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
)
if
err
!=
nil
{
log
.
Printf
(
"[Antigravity] Failed to filter empty parts: %v"
,
err
)
filteredBody
=
body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
b
ody
)
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
filteredB
ody
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
...
@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream
:
s
.
httpUpstream
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
...
@@ -1493,7 +1914,7 @@ handleSuccess:
...
@@ -1493,7 +1914,7 @@ handleSuccess:
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
Model
:
original
Model
,
Model
:
billing
Model
,
Stream
:
stream
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
...
@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
...
@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
return
true
return
true
}
}
func
antigravityHasAccountSwitch
(
ctx
context
.
Context
)
bool
{
if
ctx
==
nil
{
return
false
}
if
v
,
ok
:=
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
)
.
(
int
);
ok
{
return
v
>
0
}
return
false
}
func
antigravityMaxRetries
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesEnv
))
if
raw
==
""
{
return
antigravityDefaultMaxRetries
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityDefaultMaxRetries
}
return
value
}
func
antigravityMaxRetriesAfterSwitch
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesAfterSwitchEnv
))
if
raw
==
""
{
return
antigravityMaxRetries
()
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityMaxRetries
()
}
return
value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func
antigravityMaxRetriesForModel
(
model
string
,
afterSwitch
bool
)
int
{
var
envKey
string
if
strings
.
HasPrefix
(
model
,
"claude-"
)
{
envKey
=
antigravityMaxRetriesClaudeEnv
}
else
if
isImageGenerationModel
(
model
)
{
envKey
=
antigravityMaxRetriesGeminiImageEnv
}
else
if
strings
.
HasPrefix
(
model
,
"gemini-"
)
{
envKey
=
antigravityMaxRetriesGeminiTextEnv
}
if
envKey
!=
""
{
if
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envKey
));
raw
!=
""
{
if
value
,
err
:=
strconv
.
Atoi
(
raw
);
err
==
nil
&&
value
>
0
{
return
value
}
}
}
if
afterSwitch
{
return
antigravityMaxRetriesAfterSwitch
()
}
return
antigravityMaxRetries
()
}
func
antigravityUseMappedModelForBilling
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityBillingModelEnv
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
antigravityFallbackCooldownSeconds
()
(
time
.
Duration
,
bool
)
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityFallbackSecondsEnv
))
if
raw
==
""
{
return
0
,
false
}
seconds
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
seconds
<=
0
{
return
0
,
false
}
return
time
.
Duration
(
seconds
)
*
time
.
Second
,
true
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
if
statusCode
==
429
{
...
@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
...
@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
}
}
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
if
fallbackDur
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
fallbackDur
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
...
@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
...
@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
}
func
(
s
*
AntigravityGatewayService
)
WriteMappedClaudeError
(
c
*
gin
.
Context
,
account
*
Account
,
upstreamStatus
int
,
upstreamRequestID
string
,
body
[]
byte
)
error
{
return
s
.
writeMappedClaudeError
(
c
,
account
,
upstreamStatus
,
upstreamRequestID
,
body
)
}
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
statusStr
:=
"UNKNOWN"
statusStr
:=
"UNKNOWN"
switch
status
{
switch
status
{
...
@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
...
@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return
json
.
Marshal
(
payload
)
return
json
.
Marshal
(
payload
)
}
}
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
func
filterEmptyPartsFromGeminiRequest
(
body
[]
byte
)
([]
byte
,
error
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
nil
,
err
}
contents
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contents
)
==
0
{
return
body
,
nil
}
filtered
:=
make
([]
any
,
0
,
len
(
contents
))
modified
:=
false
for
_
,
c
:=
range
contents
{
contentMap
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
parts
,
hasParts
:=
contentMap
[
"parts"
]
if
!
hasParts
{
filtered
=
append
(
filtered
,
c
)
continue
}
partsSlice
,
ok
:=
parts
.
([]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
// 跳过 parts 为空数组的消息
if
len
(
partsSlice
)
==
0
{
modified
=
true
continue
}
filtered
=
append
(
filtered
,
c
)
}
if
!
modified
{
return
body
,
nil
}
payload
[
"contents"
]
=
filtered
return
json
.
Marshal
(
payload
)
}
backend/internal/service/antigravity_gateway_service_test.go
View file @
de7ff902
package
service
package
service
import
(
import
(
"bytes"
"context"
"encoding/json"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
...
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
...
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
}
}
func
TestIsPromptTooLongError
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)))
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"message":"Prompt is too long"}`
)))
require
.
False
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"other"}}`
)))
}
type
httpUpstreamStub
struct
{
resp
*
http
.
Response
err
error
}
func
(
s
*
httpUpstreamStub
)
Do
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
(
s
*
httpUpstreamStub
)
DoWithTLS
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
,
_
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
body
,
err
:=
json
.
Marshal
(
map
[
string
]
any
{
"model"
:
"claude-opus-4-5"
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"hi"
},
},
"max_tokens"
:
1
,
"stream"
:
false
,
})
require
.
NoError
(
t
,
err
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
=
req
respBody
:=
[]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusBadRequest
,
Header
:
http
.
Header
{
"X-Request-Id"
:
[]
string
{
"req-1"
}},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
svc
:=
&
AntigravityGatewayService
{
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"acc-1"
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Nil
(
t
,
result
)
var
promptErr
*
PromptTooLongError
require
.
ErrorAs
(
t
,
err
,
&
promptErr
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
promptErr
.
StatusCode
)
require
.
Equal
(
t
,
"req-1"
,
promptErr
.
RequestID
)
require
.
NotEmpty
(
t
,
promptErr
.
Body
)
raw
,
ok
:=
c
.
Get
(
OpsUpstreamErrorsKey
)
require
.
True
(
t
,
ok
)
events
,
ok
:=
raw
.
([]
*
OpsUpstreamErrorEvent
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
events
,
1
)
require
.
Equal
(
t
,
"prompt_too_long"
,
events
[
0
]
.
Kind
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitch
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"4"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
"7"
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
false
)
require
.
Equal
(
t
,
4
,
got
)
got
=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
true
)
require
.
Equal
(
t
,
7
,
got
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitchFallback
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"5"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"gemini-2.5-flash"
,
true
)
require
.
Equal
(
t
,
5
,
got
)
}
backend/internal/service/antigravity_quota_scope.go
View file @
de7ff902
package
service
package
service
import
(
import
(
"slices"
"strings"
"strings"
"time"
"time"
)
)
...
@@ -16,6 +17,21 @@ const (
...
@@ -16,6 +17,21 @@ const (
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
)
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func
IsScopeSupported
(
supportedScopes
[]
string
,
scope
AntigravityQuotaScope
)
bool
{
if
len
(
supportedScopes
)
==
0
{
// 未配置时默认全部支持
return
true
}
supported
:=
slices
.
Contains
(
supportedScopes
,
string
(
scope
))
return
supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func
ResolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
return
resolveAntigravityQuotaScope
(
requestedModel
)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
model
:=
normalizeAntigravityModelName
(
requestedModel
)
...
...
backend/internal/service/api_key.go
View file @
de7ff902
...
@@ -2,6 +2,14 @@ package service
...
@@ -2,6 +2,14 @@ package service
import
"time"
import
"time"
// API Key status constants
const
(
StatusAPIKeyActive
=
"active"
StatusAPIKeyDisabled
=
"disabled"
StatusAPIKeyQuotaExhausted
=
"quota_exhausted"
StatusAPIKeyExpired
=
"expired"
)
type
APIKey
struct
{
type
APIKey
struct
{
ID
int64
ID
int64
UserID
int64
UserID
int64
...
@@ -15,8 +23,53 @@ type APIKey struct {
...
@@ -15,8 +23,53 @@ type APIKey struct {
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
User
*
User
User
*
User
Group
*
Group
Group
*
Group
// Quota fields
Quota
float64
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
// Used quota amount
ExpiresAt
*
time
.
Time
// Expiration time (nil = never expires)
}
}
func
(
k
*
APIKey
)
IsActive
()
bool
{
func
(
k
*
APIKey
)
IsActive
()
bool
{
return
k
.
Status
==
StatusActive
return
k
.
Status
==
StatusActive
}
}
// IsExpired checks if the API key has expired
func
(
k
*
APIKey
)
IsExpired
()
bool
{
if
k
.
ExpiresAt
==
nil
{
return
false
}
return
time
.
Now
()
.
After
(
*
k
.
ExpiresAt
)
}
// IsQuotaExhausted checks if the API key quota is exhausted
func
(
k
*
APIKey
)
IsQuotaExhausted
()
bool
{
if
k
.
Quota
<=
0
{
return
false
// unlimited
}
return
k
.
QuotaUsed
>=
k
.
Quota
}
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
func
(
k
*
APIKey
)
GetQuotaRemaining
()
float64
{
if
k
.
Quota
<=
0
{
return
-
1
// unlimited
}
remaining
:=
k
.
Quota
-
k
.
QuotaUsed
if
remaining
<
0
{
return
0
}
return
remaining
}
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
func
(
k
*
APIKey
)
GetDaysUntilExpiry
()
int
{
if
k
.
ExpiresAt
==
nil
{
return
-
1
// never expires
}
duration
:=
time
.
Until
(
*
k
.
ExpiresAt
)
if
duration
<
0
{
return
0
}
return
int
(
duration
.
Hours
()
/
24
)
}
backend/internal/service/api_key_auth_cache.go
View file @
de7ff902
package
service
package
service
import
"time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type
APIKeyAuthSnapshot
struct
{
type
APIKeyAuthSnapshot
struct
{
APIKeyID
int64
`json:"api_key_id"`
APIKeyID
int64
`json:"api_key_id"`
...
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
...
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
IPBlacklist
[]
string
`json:"ip_blacklist,omitempty"`
IPBlacklist
[]
string
`json:"ip_blacklist,omitempty"`
User
APIKeyAuthUserSnapshot
`json:"user"`
User
APIKeyAuthUserSnapshot
`json:"user"`
Group
*
APIKeyAuthGroupSnapshot
`json:"group,omitempty"`
Group
*
APIKeyAuthGroupSnapshot
`json:"group,omitempty"`
// Quota fields for API Key independent quota feature
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount
// Expiration field for API Key expiration feature
ExpiresAt
*
time
.
Time
`json:"expires_at,omitempty"`
// Expiration time (nil = never expires)
}
}
// APIKeyAuthUserSnapshot 用户快照
// APIKeyAuthUserSnapshot 用户快照
...
@@ -23,29 +32,34 @@ type APIKeyAuthUserSnapshot struct {
...
@@ -23,29 +32,34 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照
// APIKeyAuthGroupSnapshot 分组快照
type
APIKeyAuthGroupSnapshot
struct
{
type
APIKeyAuthGroupSnapshot
struct
{
ID
int64
`json:"id"`
ID
int64
`json:"id"`
Name
string
`json:"name"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360,omitempty"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360,omitempty"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540,omitempty"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd,omitempty"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
MCPXMLInject
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes,omitempty"`
}
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
de7ff902
...
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
...
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Status
:
apiKey
.
Status
,
Status
:
apiKey
.
Status
,
IPWhitelist
:
apiKey
.
IPWhitelist
,
IPWhitelist
:
apiKey
.
IPWhitelist
,
IPBlacklist
:
apiKey
.
IPBlacklist
,
IPBlacklist
:
apiKey
.
IPBlacklist
,
Quota
:
apiKey
.
Quota
,
QuotaUsed
:
apiKey
.
QuotaUsed
,
ExpiresAt
:
apiKey
.
ExpiresAt
,
User
:
APIKeyAuthUserSnapshot
{
User
:
APIKeyAuthUserSnapshot
{
ID
:
apiKey
.
User
.
ID
,
ID
:
apiKey
.
User
.
ID
,
Status
:
apiKey
.
User
.
Status
,
Status
:
apiKey
.
User
.
Status
,
...
@@ -223,26 +226,29 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
...
@@ -223,26 +226,29 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
}
}
if
apiKey
.
Group
!=
nil
{
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
SoraImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
SoraImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
SoraVideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
FallbackGroupIDOnInvalidRequest
:
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
apiKey
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
apiKey
.
Group
.
SupportedModelScopes
,
}
}
}
}
return
snapshot
return
snapshot
...
@@ -260,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
...
@@ -260,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Status
:
snapshot
.
Status
,
Status
:
snapshot
.
Status
,
IPWhitelist
:
snapshot
.
IPWhitelist
,
IPWhitelist
:
snapshot
.
IPWhitelist
,
IPBlacklist
:
snapshot
.
IPBlacklist
,
IPBlacklist
:
snapshot
.
IPBlacklist
,
Quota
:
snapshot
.
Quota
,
QuotaUsed
:
snapshot
.
QuotaUsed
,
ExpiresAt
:
snapshot
.
ExpiresAt
,
User
:
&
User
{
User
:
&
User
{
ID
:
snapshot
.
User
.
ID
,
ID
:
snapshot
.
User
.
ID
,
Status
:
snapshot
.
User
.
Status
,
Status
:
snapshot
.
User
.
Status
,
...
@@ -270,27 +279,30 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
...
@@ -270,27 +279,30 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
}
}
if
snapshot
.
Group
!=
nil
{
if
snapshot
.
Group
!=
nil
{
apiKey
.
Group
=
&
Group
{
apiKey
.
Group
=
&
Group
{
ID
:
snapshot
.
Group
.
ID
,
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
snapshot
.
Group
.
SoraImagePrice360
,
SoraImagePrice360
:
snapshot
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
snapshot
.
Group
.
SoraImagePrice540
,
SoraImagePrice540
:
snapshot
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
snapshot
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequest
:
snapshot
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
snapshot
.
Group
.
SoraVideoPricePerRequestHD
,
SoraVideoPricePerRequestHD
:
snapshot
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
FallbackGroupIDOnInvalidRequest
:
snapshot
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
snapshot
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
}
}
}
}
return
apiKey
return
apiKey
...
...
backend/internal/service/api_key_service.go
View file @
de7ff902
...
@@ -24,6 +24,10 @@ var (
...
@@ -24,6 +24,10 @@ var (
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
ErrAPIKeyExpired
=
infraerrors
.
Forbidden
(
"API_KEY_EXPIRED"
,
"api key 已过期"
)
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted
=
infraerrors
.
TooManyRequests
(
"API_KEY_QUOTA_EXHAUSTED"
,
"api key 额度已用完"
)
)
)
const
(
const
(
...
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
...
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
// Quota methods
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
}
}
// APIKeyCache defines cache operations for API key service
// APIKeyCache defines cache operations for API key service
...
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
...
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
// Quota fields
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
ExpiresInDays
*
int
`json:"expires_in_days"`
// Days until expiry (nil = never expires)
}
}
// UpdateAPIKeyRequest 更新API Key请求
// UpdateAPIKeyRequest 更新API Key请求
...
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
...
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
Status
*
string
`json:"status"`
Status
*
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
// Quota fields
Quota
*
float64
`json:"quota"`
// Quota limit in USD (nil = no change, 0 = unlimited)
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = no change)
ClearExpiration
bool
`json:"-"`
// Clear expiration (internal use)
ResetQuota
*
bool
`json:"reset_quota"`
// Reset quota_used to 0
}
}
// APIKeyService API Key服务
// APIKeyService API Key服务
...
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
...
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
Status
:
StatusActive
,
Status
:
StatusActive
,
IPWhitelist
:
req
.
IPWhitelist
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
QuotaUsed
:
0
,
}
// Set expiration time if specified
if
req
.
ExpiresInDays
!=
nil
&&
*
req
.
ExpiresInDays
>
0
{
expiresAt
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
*
req
.
ExpiresInDays
)
apiKey
.
ExpiresAt
=
&
expiresAt
}
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
...
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
...
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
}
}
// Update quota fields
if
req
.
Quota
!=
nil
{
apiKey
.
Quota
=
*
req
.
Quota
// If quota is increased and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
&&
*
req
.
Quota
>
apiKey
.
QuotaUsed
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ResetQuota
!=
nil
&&
*
req
.
ResetQuota
{
apiKey
.
QuotaUsed
=
0
// If resetting quota and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ClearExpiration
{
apiKey
.
ExpiresAt
=
nil
// If clearing expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
{
apiKey
.
Status
=
StatusActive
}
}
else
if
req
.
ExpiresAt
!=
nil
{
apiKey
.
ExpiresAt
=
req
.
ExpiresAt
// If extending expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
&&
time
.
Now
()
.
Before
(
*
req
.
ExpiresAt
)
{
apiKey
.
Status
=
StatusActive
}
}
// 更新 IP 限制(空数组会清空设置)
// 更新 IP 限制(空数组会清空设置)
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
...
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
...
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
}
}
return
keys
,
nil
return
keys
,
nil
}
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func
(
s
*
APIKeyService
)
CheckAPIKeyQuotaAndExpiry
(
apiKey
*
APIKey
)
error
{
// Check expiration
if
apiKey
.
IsExpired
()
{
return
ErrAPIKeyExpired
}
// Check quota
if
apiKey
.
IsQuotaExhausted
()
{
return
ErrAPIKeyQuotaExhausted
}
return
nil
}
// UpdateQuotaUsed updates the quota_used field after a request
// Also checks if quota is exhausted and updates status accordingly
func
(
s
*
APIKeyService
)
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
if
cost
<=
0
{
return
nil
}
// Use repository to atomically increment quota_used
newQuotaUsed
,
err
:=
s
.
apiKeyRepo
.
IncrementQuotaUsed
(
ctx
,
apiKeyID
,
cost
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"increment quota used: %w"
,
err
)
}
// Check if quota is now exhausted and update status if needed
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
apiKeyID
)
if
err
!=
nil
{
return
nil
// Don't fail the request, just log
}
// If quota is set and now exhausted, update status
if
apiKey
.
Quota
>
0
&&
newQuotaUsed
>=
apiKey
.
Quota
{
apiKey
.
Status
=
StatusAPIKeyQuotaExhausted
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
// Don't fail the request
}
// Invalidate cache so next request sees the new status
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
}
return
nil
}
backend/internal/service/api_key_service_cache_test.go
View file @
de7ff902
...
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
...
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
return
s
.
listKeysByGroupID
(
ctx
,
groupID
)
return
s
.
listKeysByGroupID
(
ctx
,
groupID
)
}
}
func
(
s
*
authRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
type
authCacheStub
struct
{
type
authCacheStub
struct
{
getAuthCache
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
getAuthCache
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
setAuthKeys
[]
string
setAuthKeys
[]
string
...
...
backend/internal/service/api_key_service_delete_test.go
View file @
de7ff902
...
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
...
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
panic
(
"unexpected ListKeysByGroupID call"
)
panic
(
"unexpected ListKeysByGroupID call"
)
}
}
func
(
s
*
apiKeyRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
//
...
...
backend/internal/service/auth_service.go
View file @
de7ff902
...
@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
}
}
// 应用优惠码(如果提供且功能已启用)
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
backend/internal/service/domain_constants.go
View file @
de7ff902
...
@@ -32,6 +32,7 @@ const (
...
@@ -32,6 +32,7 @@ const (
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
)
)
// Redeem type constants
// Redeem type constants
...
...
backend/internal/service/gateway_service.go
View file @
de7ff902
...
@@ -257,6 +257,9 @@ var (
...
@@ -257,6 +257,9 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var
ErrClaudeCodeOnly
=
errors
.
New
(
"this group only allows Claude Code clients"
)
var
ErrClaudeCodeOnly
=
errors
.
New
(
"this group only allows Claude Code clients"
)
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var
ErrModelScopeNotSupported
=
errors
.
New
(
"model scope not supported by this group"
)
// allowedHeaders 白名单headers(参考CRS项目)
// allowedHeaders 白名单headers(参考CRS项目)
var
allowedHeaders
=
map
[
string
]
bool
{
var
allowedHeaders
=
map
[
string
]
bool
{
"accept"
:
true
,
"accept"
:
true
,
...
@@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string {
...
@@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string {
}
}
// replaceModelInBody 替换请求体中的model字段
// replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
func
(
s
*
GatewayService
)
replaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
func
(
s
*
GatewayService
)
replaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
var
req
map
[
string
]
any
var
req
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
return
body
}
}
req
[
"model"
]
=
newModel
// 只序列化 model 字段
modelBytes
,
err
:=
json
.
Marshal
(
newModel
)
if
err
!=
nil
{
return
body
}
req
[
"model"
]
=
modelBytes
newBody
,
err
:=
json
.
Marshal
(
req
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
return
body
return
body
...
@@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
len
(
body
)
==
0
{
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
return
body
,
modelID
,
nil
}
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var
reqRaw
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqRaw
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
var
req
map
[
string
]
any
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
return
body
,
modelID
,
nil
}
}
toolNameMap
:=
make
(
map
[
string
]
string
)
toolNameMap
:=
make
(
map
[
string
]
string
)
modified
:=
false
if
system
,
ok
:=
req
[
"system"
];
ok
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
switch
v
:=
system
.
(
type
)
{
...
@@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
v
)
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
req
[
"system"
]
=
sanitized
modified
=
true
}
}
case
[]
any
:
case
[]
any
:
for
_
,
item
:=
range
v
{
for
_
,
item
:=
range
v
{
...
@@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
text
)
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
block
[
"text"
]
=
sanitized
modified
=
true
}
}
}
}
}
}
...
@@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
normalized
!=
rawModel
{
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
req
[
"model"
]
=
normalized
modelID
=
normalized
modelID
=
normalized
modified
=
true
}
}
}
}
...
@@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
toolMap
[
"name"
]
=
normalized
modified
=
true
}
}
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
toolMap
[
"description"
]
=
sanitized
modified
=
true
}
}
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
normalizeToolInputSchema
(
schema
,
toolNameMap
)
modified
=
true
}
}
tools
[
idx
]
=
toolMap
tools
[
idx
]
=
toolMap
}
}
...
@@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalizedTools
[
normalized
]
=
value
normalizedTools
[
normalized
]
=
value
}
}
req
[
"tools"
]
=
normalizedTools
req
[
"tools"
]
=
normalizedTools
modified
=
true
}
}
}
else
{
}
else
{
req
[
"tools"
]
=
[]
any
{}
req
[
"tools"
]
=
[]
any
{}
modified
=
true
}
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified
:=
false
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
...
@@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
!
ok
{
if
!
ok
{
continue
continue
}
}
// 检查此消息是否包含 thinking 块
hasThinking
:=
false
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
if
blockType
==
"thinking"
||
blockType
==
"redacted_thinking"
{
hasThinking
=
true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if
hasThinking
{
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for
_
,
block
:=
range
content
{
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
if
!
ok
{
...
@@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
blockMap
[
"name"
]
=
normalized
messagesModified
=
true
}
}
}
}
}
}
...
@@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
opts
.
stripSystemCacheControl
{
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
_
=
stripCacheControlFromSystemBlocks
(
system
)
modified
=
true
}
}
}
}
...
@@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
...
@@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
metadata
[
"user_id"
]
=
opts
.
metadataUserID
modified
=
true
}
}
}
}
delete
(
req
,
"temperature"
)
if
_
,
hasTemp
:=
req
[
"temperature"
];
hasTemp
{
delete
(
req
,
"tool_choice"
)
delete
(
req
,
"temperature"
)
modified
=
true
}
if
_
,
hasChoice
:=
req
[
"tool_choice"
];
hasChoice
{
delete
(
req
,
"tool_choice"
)
modified
=
true
}
if
!
modified
&&
!
messagesModified
{
return
body
,
modelID
,
toolNameMap
}
// 如果 messages 没有被修改,保留原始 messages 字节
if
!
messagesModified
{
// 序列化非 messages 字段
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
// 替换回原始的 messages
var
newReq
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
newReq
);
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
if
origMessages
,
ok
:=
reqRaw
[
"messages"
];
ok
{
newReq
[
"messages"
]
=
origMessages
}
finalBody
,
err
:=
json
.
Marshal
(
newReq
)
if
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
return
finalBody
,
modelID
,
toolNameMap
}
// messages 被修改了,需要完整序列化
newBody
,
err
:=
json
.
Marshal
(
req
)
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
return
body
,
modelID
,
toolNameMap
...
@@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
...
@@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return
group
,
nil
return
group
,
nil
}
}
func
(
s
*
GatewayService
)
ResolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
return
s
.
resolveGroupByID
(
ctx
,
groupID
)
}
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
return
nil
return
nil
...
@@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
...
@@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
}
}
// 强制平台模式不检查 Claude Code 限制
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
if
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
&&
forcePlatform
!=
""
{
return
nil
,
groupID
,
nil
return
nil
,
groupID
,
nil
}
}
...
@@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) {
...
@@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
preferOAuth
:=
platform
==
PlatformGemini
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
...
@@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
...
@@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
return
IsAntigravityModelSupported
(
requestedModel
)
}
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
return
true
return
true
...
@@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个)
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:
=
account
.
GetMappedModel
(
reqModel
)
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
mappingSource
=
"account"
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s)"
,
originalModel
,
mappedModel
,
account
.
Name
)
}
}
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s, source=%s)"
,
originalModel
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
// 获取凭证
// 获取凭证
token
,
tokenType
,
err
:=
s
.
GetAccessToken
(
ctx
,
account
)
token
,
tokenType
,
err
:=
s
.
GetAccessToken
(
ctx
,
account
)
...
@@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
...
@@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return
true
return
true
}
}
// 检测 thinking block 被修改的错误
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
log
.
Printf
(
"[SignatureCheck] Detected thinking block modification error"
)
return
true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
// 例如: "all messages must have non-empty content"
if
strings
.
Contains
(
msg
,
"non-empty content"
)
||
strings
.
Contains
(
msg
,
"empty content"
)
{
if
strings
.
Contains
(
msg
,
"non-empty content"
)
||
strings
.
Contains
(
msg
,
"empty content"
)
{
...
@@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
...
@@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
APIKey
*
APIKey
APIKey
*
APIKey
User
*
User
User
*
User
Account
*
Account
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
IPAddress
string
// 请求的客户端 IP 地址
APIKeyService
APIKeyQuotaUpdater
// 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
type
APIKeyQuotaUpdater
interface
{
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
// RecordUsage 记录使用量并扣费(或更新订阅用量)
...
@@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
...
@@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
}
}
}
// 更新 API Key 配额(如果设置了配额限制)
if
shouldBill
&&
cost
.
ActualCost
>
0
&&
apiKey
.
Quota
>
0
&&
input
.
APIKeyService
!=
nil
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Update API key quota failed: %v"
,
err
)
}
}
// Schedule batch update for account last_used_at
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
...
@@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct {
...
@@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct {
IPAddress
string
// 请求的客户端 IP 地址
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
APIKeyService
*
APIKeyService
// API Key 配额服务(可选)
}
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...
@@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
...
@@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
}
// 异步更新余额缓存
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
// API Key 独立配额扣费
if
input
.
APIKeyService
!=
nil
&&
apiKey
.
Quota
>
0
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Add API key quota used failed: %v"
,
err
)
}
}
}
}
}
}
...
@@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
...
@@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return
nil
return
nil
}
}
// 应用模型映射(仅对 apikey 类型账号)
// 应用模型映射:
if
account
.
Type
==
AccountTypeAPIKey
{
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
if
reqModel
!=
""
{
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
if
reqModel
!=
""
{
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
mappingSource
=
"account"
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
)
}
}
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s, source=%s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
}
}
// 获取凭证
// 获取凭证
...
@@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
...
@@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return
normalized
,
nil
return
normalized
,
nil
}
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func
(
s
*
GatewayService
)
checkAntigravityModelScope
(
ctx
context
.
Context
,
groupID
int64
,
requestedModel
string
)
error
{
scope
,
ok
:=
ResolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
nil
// 无法解析 scope,跳过检查
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
// 查询失败时放行
}
if
group
==
nil
{
return
nil
// 分组不存在时放行
}
if
!
IsScopeSupported
(
group
.
SupportedModelScopes
,
scope
)
{
return
ErrModelScopeNotSupported
}
return
nil
}
// GetAvailableModels returns the list of models available for a group
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
de7ff902
...
@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
if
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
);
err
==
nil
{
body
=
filteredBody
}
switch
action
{
switch
action
{
case
"generateContent"
,
"streamGenerateContent"
,
"countTokens"
:
case
"generateContent"
,
"streamGenerateContent"
,
"countTokens"
:
// ok
// ok
...
...
backend/internal/service/gemini_native_signature_cleaner.go
View file @
de7ff902
...
@@ -2,20 +2,22 @@ package service
...
@@ -2,20 +2,22 @@ package service
import
(
import
(
"encoding/json"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
移除
thoughtSignature 字段,
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
替换
thoughtSignature 字段
为 dummy 签名
,
// 以避免跨账号签名验证错误。
// 以避免跨账号签名验证错误。
//
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过
移除这些签名,让新账号重新生成有效的签名
。
// 会导致新账号的签名验证失败。通过
替换为 dummy 签名,跳过签名验证
。
//
//
// CleanGeminiNativeThoughtSignatures re
mov
es thoughtSignature fields
from Gemini native API requests
// CleanGeminiNativeThoughtSignatures re
plac
es thoughtSignature fields
with dummy signature
// to avoid cross-account signature validation errors.
//
in Gemini native API requests
to avoid cross-account signature validation errors.
//
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// thoughtSignatures from the old account will cause validation failures on the new account.
// By re
moving these
signature
s
, we
allow the new account to generate valid signatures
.
// By re
placing with dummy
signature, we
skip signature validation
.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
if
len
(
body
)
==
0
{
return
body
return
body
...
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
...
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
body
return
body
}
}
// 递归
清理
thoughtSignature
// 递归
替换
thoughtSignature
为 dummy 签名
cleaned
:=
clean
ThoughtSignaturesRecursive
(
data
)
replaced
:=
replace
ThoughtSignaturesRecursive
(
data
)
// 重新序列化
// 重新序列化
result
,
err
:=
json
.
Marshal
(
clean
ed
)
result
,
err
:=
json
.
Marshal
(
replac
ed
)
if
err
!=
nil
{
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
// 如果序列化失败,返回原始 body
return
body
return
body
...
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
...
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
result
return
result
}
}
//
clean
ThoughtSignaturesRecursive 递归遍历数据结构,
移除
所有 thoughtSignature 字段
//
replace
ThoughtSignaturesRecursive 递归遍历数据结构,
将
所有 thoughtSignature 字段
替换为 dummy 签名
func
clean
ThoughtSignaturesRecursive
(
data
any
)
any
{
func
replace
ThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
case
map
[
string
]
any
:
// 创建新的 map,
移除
thoughtSignature
// 创建新的 map,
替换
thoughtSignature
为 dummy 签名
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
for
key
,
value
:=
range
v
{
//
跳过
thoughtSignature 字段
//
替换
thoughtSignature 字段
为 dummy 签名
if
key
==
"thoughtSignature"
{
if
key
==
"thoughtSignature"
{
result
[
key
]
=
antigravity
.
DummyThoughtSignature
continue
continue
}
}
// 递归处理嵌套结构
// 递归处理嵌套结构
result
[
key
]
=
clean
ThoughtSignaturesRecursive
(
value
)
result
[
key
]
=
replace
ThoughtSignaturesRecursive
(
value
)
}
}
return
result
return
result
...
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
...
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
// 递归处理数组中的每个元素
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
for
i
,
item
:=
range
v
{
result
[
i
]
=
clean
ThoughtSignaturesRecursive
(
item
)
result
[
i
]
=
replace
ThoughtSignaturesRecursive
(
item
)
}
}
return
result
return
result
...
...
backend/internal/service/group.go
View file @
de7ff902
...
@@ -35,6 +35,8 @@ type Group struct {
...
@@ -35,6 +35,8 @@ type Group struct {
// Claude Code 客户端限制
// Claude Code 客户端限制
ClaudeCodeOnly
bool
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
FallbackGroupID
*
int64
// 无效请求兜底分组(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
...
@@ -42,6 +44,13 @@ type Group struct {
...
@@ -42,6 +44,13 @@ type Group struct {
ModelRouting
map
[
string
][]
int64
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
ModelRoutingEnabled
bool
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes
[]
string
CreatedAt
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
backend/internal/service/identity_service.go
View file @
de7ff902
...
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
...
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
// RewriteUserID 重写body中的metadata.user_id
// RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserID
(
body
[]
byte
,
accountID
int64
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
func
(
s
*
IdentityService
)
RewriteUserID
(
body
[]
byte
,
accountID
int64
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
if
len
(
body
)
==
0
||
accountUUID
==
""
||
cachedClientID
==
""
{
if
len
(
body
)
==
0
||
accountUUID
==
""
||
cachedClientID
==
""
{
return
body
,
nil
return
body
,
nil
}
}
//
解析JSON
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
any
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqMap
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
body
,
&
reqMap
);
err
!=
nil
{
return
body
,
nil
return
body
,
nil
}
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
if
!
ok
{
return
body
,
nil
return
body
,
nil
}
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
body
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
if
!
ok
||
userID
==
""
{
return
body
,
nil
return
body
,
nil
...
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
...
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
newUserID
:=
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
cachedClientID
,
accountUUID
,
newSessionHash
)
newUserID
:=
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
cachedClientID
,
accountUUID
,
newSessionHash
)
metadata
[
"user_id"
]
=
newUserID
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
err
:=
json
.
Marshal
(
metadata
)
if
err
!=
nil
{
return
body
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
return
json
.
Marshal
(
reqMap
)
}
}
...
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
...
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
// 先执行常规的 RewriteUserID 逻辑
// 先执行常规的 RewriteUserID 逻辑
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
...
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
...
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return
newBody
,
nil
return
newBody
,
nil
}
}
//
解析重写后的 body,提取 user_id
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
any
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
return
newBody
,
nil
return
newBody
,
nil
}
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
if
!
ok
{
return
newBody
,
nil
return
newBody
,
nil
}
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
newBody
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
if
!
ok
||
userID
==
""
{
return
newBody
,
nil
return
newBody
,
nil
...
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
...
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
)
)
metadata
[
"user_id"
]
=
newUserID
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
marshalErr
:=
json
.
Marshal
(
metadata
)
if
marshalErr
!=
nil
{
return
newBody
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
return
json
.
Marshal
(
reqMap
)
}
}
...
...
backend/internal/service/openai_codex_transform.go
View file @
de7ff902
...
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
...
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
LastChecked
int64
`json:"lastChecked"`
LastChecked
int64
`json:"lastChecked"`
}
}
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
)
codexTransformResult
{
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
codexTransformResult
{
result
:=
codexTransformResult
{}
result
:=
codexTransformResult
{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation
:=
NeedsToolContinuation
(
reqBody
)
needsToolContinuation
:=
NeedsToolContinuation
(
reqBody
)
...
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
...
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result
.
PromptCacheKey
=
strings
.
TrimSpace
(
v
)
result
.
PromptCacheKey
=
strings
.
TrimSpace
(
v
)
}
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
applyInstructions
(
reqBody
,
isCodexCLI
)
{
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
result
.
Modified
=
true
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
result
.
Modified
=
true
}
}
else
if
existingInstructions
==
""
{
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
result
.
Modified
=
true
}
}
}
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
...
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
...
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
return
getCodexCLIInstructions
()
return
getCodexCLIInstructions
()
}
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
func
applyInstructions
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
bool
{
if
isCodexCLI
{
return
applyCodexCLIInstructions
(
reqBody
)
}
return
applyOpenCodeInstructions
(
reqBody
)
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
func
applyCodexCLIInstructions
(
reqBody
map
[
string
]
any
)
bool
{
if
!
isInstructionsEmpty
(
reqBody
)
{
return
false
// 已有有效 instructions,不修改
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
if
instructions
!=
""
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
return
false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
func
applyOpenCodeInstructions
(
reqBody
map
[
string
]
any
)
bool
{
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
}
else
if
existingInstructions
==
""
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
return
true
}
}
return
false
}
// isInstructionsEmpty 检查 instructions 字段是否为空
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
func
isInstructionsEmpty
(
reqBody
map
[
string
]
any
)
bool
{
val
,
exists
:=
reqBody
[
"instructions"
]
if
!
exists
{
return
true
}
if
val
==
nil
{
return
true
}
str
,
ok
:=
val
.
(
string
)
if
!
ok
{
return
true
}
return
strings
.
TrimSpace
(
str
)
==
""
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func
ReplaceWithCodexInstructions
(
reqBody
map
[
string
]
any
)
bool
{
func
ReplaceWithCodexInstructions
(
reqBody
map
[
string
]
any
)
bool
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment