Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
dd96ada3
Unverified
Commit
dd96ada3
authored
Feb 04, 2026
by
程序猿MT
Committed by
GitHub
Feb 04, 2026
Browse files
Merge branch 'Wei-Shaw:main' into main
parents
31fe0178
8f397548
Changes
90
Hide whitespace changes
Inline
Side-by-side
backend/internal/server/middleware/api_key_auth_test.go
View file @
dd96ada3
...
...
@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
backend/internal/service/admin_service.go
View file @
dd96ada3
...
...
@@ -111,9 +111,14 @@ type CreateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -135,9 +140,14 @@ type UpdateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -594,6 +604,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
nil
,
err
}
}
fallbackOnInvalidRequest
:=
input
.
FallbackGroupIDOnInvalidRequest
if
fallbackOnInvalidRequest
!=
nil
&&
*
fallbackOnInvalidRequest
<=
0
{
fallbackOnInvalidRequest
=
nil
}
// 校验无效请求兜底分组
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
0
,
platform
,
subscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
mcpXMLInject
:=
true
if
input
.
MCPXMLInject
!=
nil
{
mcpXMLInject
=
*
input
.
MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
...
...
@@ -628,22 +654,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
group
:=
&
Group
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
dailyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice4K
:
imagePrice4K
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
ModelRouting
:
input
.
ModelRouting
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
dailyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice4K
:
imagePrice4K
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
fallbackOnInvalidRequest
,
ModelRouting
:
input
.
ModelRouting
,
MCPXMLInject
:
mcpXMLInject
,
SupportedModelScopes
:
input
.
SupportedModelScopes
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -714,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
}
}
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func
(
s
*
adminServiceImpl
)
validateFallbackGroupOnInvalidRequest
(
ctx
context
.
Context
,
currentGroupID
int64
,
platform
,
subscriptionType
string
,
fallbackGroupID
int64
)
error
{
if
platform
!=
PlatformAnthropic
&&
platform
!=
PlatformAntigravity
{
return
fmt
.
Errorf
(
"invalid request fallback only supported for anthropic or antigravity groups"
)
}
if
subscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"subscription groups cannot set invalid request fallback"
)
}
if
currentGroupID
>
0
&&
currentGroupID
==
fallbackGroupID
{
return
fmt
.
Errorf
(
"cannot set self as invalid request fallback group"
)
}
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
fallbackGroupID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
if
fallbackGroup
.
Platform
!=
PlatformAnthropic
{
return
fmt
.
Errorf
(
"fallback group must be anthropic platform"
)
}
if
fallbackGroup
.
SubscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"fallback group cannot be subscription type"
)
}
if
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
return
fmt
.
Errorf
(
"fallback group cannot have invalid request fallback configured"
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
...
...
@@ -780,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group
.
FallbackGroupID
=
nil
}
}
fallbackOnInvalidRequest
:=
group
.
FallbackGroupIDOnInvalidRequest
if
input
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
if
*
input
.
FallbackGroupIDOnInvalidRequest
>
0
{
fallbackOnInvalidRequest
=
input
.
FallbackGroupIDOnInvalidRequest
}
else
{
fallbackOnInvalidRequest
=
nil
}
}
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
id
,
group
.
Platform
,
group
.
SubscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
group
.
FallbackGroupIDOnInvalidRequest
=
fallbackOnInvalidRequest
// 模型路由配置
if
input
.
ModelRouting
!=
nil
{
...
...
@@ -788,6 +862,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
ModelRoutingEnabled
!=
nil
{
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
}
if
input
.
MCPXMLInject
!=
nil
{
group
.
MCPXMLInject
=
*
input
.
MCPXMLInject
}
// 支持的模型系列(仅 antigravity 平台使用)
if
input
.
SupportedModelScopes
!=
nil
{
group
.
SupportedModelScopes
=
*
input
.
SupportedModelScopes
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
backend/internal/service/admin_service_group_test.go
View file @
dd96ada3
...
...
@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
created
*
Group
updated
*
Group
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Create
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
created
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Update
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
updated
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
fallback
*
Group
wantMessage
string
}{
{
name
:
"openai_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"antigravity_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"subscription_group"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
wantMessage
:
"fallback group cannot be subscription type"
,
},
{
name
:
"nested_fallback"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
func
()
*
int64
{
v
:=
int64
(
99
);
return
&
v
}(),
},
wantMessage
:
"fallback group cannot have invalid request fallback configured"
,
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
fallbackID
:=
tc
.
fallback
.
ID
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
tc
.
fallback
,
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
tc
.
wantMessage
)
require
.
Nil
(
t
,
repo
.
created
)
})
}
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackNotFound
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group not found"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
zero
:=
int64
(
0
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
zero
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Nil
(
t
,
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
SubscriptionType
:
SubscriptionTypeSubscription
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
clear
:=
int64
(
0
)
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
FallbackGroupIDOnInvalidRequest
:
&
clear
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Nil
(
t
,
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cannot be subscription type"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
backend/internal/service/antigravity_gateway_service.go
View file @
dd96ada3
...
...
@@ -13,23 +13,34 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const
(
antigravityStickySessionTTL
=
time
.
Hour
antigravityMaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
antigravityStickySessionTTL
=
time
.
Hour
antigravity
Default
MaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
const
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
const
(
antigravityMaxRetriesEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv
=
"GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv
=
"GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv
=
"GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数
type
antigravityRetryLoopParams
struct
{
...
...
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
action
string
body
[]
byte
quotaScope
AntigravityQuotaScope
maxRetries
int
c
*
gin
.
Context
httpUpstream
HTTPUpstream
settingService
*
SettingService
...
...
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
resp
*
http
.
Response
}
// PromptTooLongError 表示上游明确返回 prompt too long
type
PromptTooLongError
struct
{
StatusCode
int
RequestID
string
Body
[]
byte
}
func
(
e
*
PromptTooLongError
)
Error
()
string
{
return
fmt
.
Sprintf
(
"prompt too long: status=%d"
,
e
.
StatusCode
)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func
antigravityRetryLoop
(
p
antigravityRetryLoopParams
)
(
*
antigravityRetryLoopResult
,
error
)
{
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
baseURLs
:=
antigravity
.
ForwardBaseURLs
()
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLsWithBase
(
baseURLs
)
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
availableURLs
=
baseURLs
}
maxRetries
:=
p
.
maxRetries
if
maxRetries
<=
0
{
maxRetries
=
antigravityDefaultMaxRetries
}
var
resp
*
http
.
Response
...
...
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
usedBaseURL
=
baseURL
for
attempt
:=
1
;
attempt
<=
antigravityM
axRetries
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
m
axRetries
;
attempt
++
{
select
{
case
<-
p
.
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
p
.
prefix
,
p
.
ctx
.
Err
())
...
...
@@ -109,8 +138,8 @@ urlFallbackLoop:
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
urlFallbackLoop
}
if
attempt
<
antigravityM
axRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
err
)
if
attempt
<
m
axRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
m
axRetries
,
err
)
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -134,7 +163,7 @@ urlFallbackLoop:
}
// 账户/模型配额限流,重试 3 次(指数退避)
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -147,7 +176,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
200
))
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
200
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -171,7 +200,7 @@ urlFallbackLoop:
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -184,7 +213,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
500
))
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
500
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -390,6 +419,11 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func
(
s
*
AntigravityGatewayService
)
TestConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
// 上游透传账号使用专用测试方法
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
testUpstreamConnection
(
ctx
,
account
,
modelID
)
}
// 获取 token
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
...
...
@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return
nil
,
lastErr
}
// testUpstreamConnection 测试上游透传账号连接
func
(
s
*
AntigravityGatewayService
)
testUpstreamConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
errors
.
New
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 使用 Claude 模型进行测试
if
modelID
==
""
{
modelID
=
"claude-sonnet-4-20250514"
}
// 构建最小测试请求
testReq
:=
map
[
string
]
any
{
"model"
:
modelID
,
"max_tokens"
:
1
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"."
},
},
}
requestBody
,
err
:=
json
.
Marshal
(
testReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"构建请求失败: %w"
,
err
)
}
// 构建 HTTP 请求
upstreamURL
:=
baseURL
+
"/v1/messages"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
requestBody
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
log
.
Printf
(
"[antigravity-Test-Upstream] account=%s url=%s"
,
account
.
Name
,
upstreamURL
)
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
respBody
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
if
resp
.
StatusCode
>=
400
{
return
nil
,
fmt
.
Errorf
(
"API 返回 %d: %s"
,
resp
.
StatusCode
,
string
(
respBody
))
}
// 提取响应文本
var
respData
map
[
string
]
any
text
:=
""
if
json
.
Unmarshal
(
respBody
,
&
respData
)
==
nil
{
if
content
,
ok
:=
respData
[
"content"
]
.
([]
any
);
ok
&&
len
(
content
)
>
0
{
if
block
,
ok
:=
content
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
t
,
ok
:=
block
[
"text"
]
.
(
string
);
ok
{
text
=
t
}
}
}
}
return
&
TestConnectionResult
{
Text
:
text
,
MappedModel
:
modelID
,
},
nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func
(
s
*
AntigravityGatewayService
)
buildGeminiTestRequest
(
projectID
,
model
string
)
([]
byte
,
error
)
{
...
...
@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
}
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
group
!=
nil
{
opts
.
EnableMCPXML
=
group
.
MCPXMLInject
}
return
opts
}
...
...
@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
// 上游透传账号直接转发,不走 OAuth token 刷新
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
ForwardUpstream
(
ctx
,
c
,
account
,
body
)
}
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
...
...
@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
...
...
@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
retryErr
!=
nil
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试)
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
==
http
.
StatusBadRequest
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
log
.
Printf
(
"%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s"
,
prefix
,
isPromptTooLongError
(
respBody
),
upstreamMsg
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateForLog
(
respBody
,
500
))
}
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isPromptTooLongError
(
respBody
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"prompt_too_long"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
PromptTooLongError
{
StatusCode
:
resp
.
StatusCode
,
RequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Body
:
respBody
,
}
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
...
...
@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
//
使用原始模型用于计费和日志
Model
:
billing
Model
,
//
计费模型(可按映射模型覆盖)
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
return
true
}
// Detect thinking block modification errors:
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
return
true
}
return
false
}
func
isPromptTooLongError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
)))
if
msg
==
""
{
msg
=
strings
.
ToLower
(
string
(
respBody
))
}
return
strings
.
Contains
(
msg
,
"prompt is too long"
)
}
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
""
}
parseNestedMessage
:=
func
(
msg
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
msg
)
if
trimmed
==
""
||
!
strings
.
HasPrefix
(
trimmed
,
"{"
)
{
return
""
}
var
nested
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
nested
);
err
!=
nil
{
return
""
}
if
errObj
,
ok
:=
nested
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
innerMsg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
}
if
innerMsg
,
ok
:=
nested
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
return
""
}
// Google-style: {"error": {"message": "..."}}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
}
}
// Fallback: top-level message
if
msg
,
ok
:=
payload
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
}
...
...
@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return
changed
,
nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func
(
s
*
AntigravityGatewayService
)
ForwardUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
// 获取上游配置
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 解析请求获取模型信息
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
// 创建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
// 设置请求头
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
// Claude API 兼容
// 透传 Claude 相关 headers
if
v
:=
c
.
GetHeader
(
"anthropic-version"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
v
)
}
if
v
:=
c
.
GetHeader
(
"anthropic-beta"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
v
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
log
.
Printf
(
"%s upstream request failed: %v"
,
prefix
,
err
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 429 错误时标记账号限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
AntigravityQuotaScopeClaude
)
}
// 透传上游错误
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
resp
.
StatusCode
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billingModel
,
},
nil
}
// 处理成功响应(流式/非流式)
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
claudeReq
.
Stream
{
// 流式响应:透传
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
usage
,
firstTokenMs
=
s
.
streamUpstreamResponse
(
c
,
resp
,
startTime
)
}
else
{
// 非流式响应:直接透传
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read upstream response: %w"
,
err
)
}
// 提取 usage
usage
=
s
.
extractClaudeUsage
(
respBody
)
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
http
.
StatusOK
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
}
// 构建计费结果
duration
:=
time
.
Since
(
startTime
)
log
.
Printf
(
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billingModel
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{
InputTokens
:
usage
.
InputTokens
,
OutputTokens
:
usage
.
OutputTokens
,
CacheReadInputTokens
:
usage
.
CacheReadInputTokens
,
CacheCreationInputTokens
:
usage
.
CacheCreationInputTokens
,
},
},
nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
ClaudeUsage
,
*
int
)
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenRecorded
bool
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
buf
:=
make
([]
byte
,
0
,
64
*
1024
)
scanner
.
Buffer
(
buf
,
1024
*
1024
)
for
scanner
.
Scan
()
{
line
:=
scanner
.
Bytes
()
// 记录首 token 时间
if
!
firstTokenRecorded
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenRecorded
=
true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if
bytes
.
HasPrefix
(
line
,
[]
byte
(
"data: "
))
{
dataStr
:=
bytes
.
TrimPrefix
(
line
,
[]
byte
(
"data: "
))
var
event
map
[
string
]
any
if
json
.
Unmarshal
(
dataStr
,
&
event
)
==
nil
{
if
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
}
}
// 透传行
_
,
_
=
c
.
Writer
.
Write
(
line
)
_
,
_
=
c
.
Writer
.
Write
([]
byte
(
"
\n
"
))
c
.
Writer
.
Flush
()
}
return
usage
,
firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
// ForwardGemini 转发 Gemini 协议请求
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
mappedModel
:=
s
.
getMappedModel
(
account
,
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL
=
account
.
Proxy
.
URL
()
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
)
if
err
!=
nil
{
log
.
Printf
(
"[Antigravity] Failed to filter empty parts: %v"
,
err
)
filteredBody
=
body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
b
ody
)
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
filteredB
ody
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
err
!=
nil
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
...
...
@@ -1493,7 +1914,7 @@ handleSuccess:
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
Model
:
billing
Model
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
return
true
}
func
antigravityHasAccountSwitch
(
ctx
context
.
Context
)
bool
{
if
ctx
==
nil
{
return
false
}
if
v
,
ok
:=
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
)
.
(
int
);
ok
{
return
v
>
0
}
return
false
}
func
antigravityMaxRetries
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesEnv
))
if
raw
==
""
{
return
antigravityDefaultMaxRetries
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityDefaultMaxRetries
}
return
value
}
func
antigravityMaxRetriesAfterSwitch
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesAfterSwitchEnv
))
if
raw
==
""
{
return
antigravityMaxRetries
()
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityMaxRetries
()
}
return
value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func
antigravityMaxRetriesForModel
(
model
string
,
afterSwitch
bool
)
int
{
var
envKey
string
if
strings
.
HasPrefix
(
model
,
"claude-"
)
{
envKey
=
antigravityMaxRetriesClaudeEnv
}
else
if
isImageGenerationModel
(
model
)
{
envKey
=
antigravityMaxRetriesGeminiImageEnv
}
else
if
strings
.
HasPrefix
(
model
,
"gemini-"
)
{
envKey
=
antigravityMaxRetriesGeminiTextEnv
}
if
envKey
!=
""
{
if
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envKey
));
raw
!=
""
{
if
value
,
err
:=
strconv
.
Atoi
(
raw
);
err
==
nil
&&
value
>
0
{
return
value
}
}
}
if
afterSwitch
{
return
antigravityMaxRetriesAfterSwitch
()
}
return
antigravityMaxRetries
()
}
func
antigravityUseMappedModelForBilling
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityBillingModelEnv
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
antigravityFallbackCooldownSeconds
()
(
time
.
Duration
,
bool
)
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityFallbackSecondsEnv
))
if
raw
==
""
{
return
0
,
false
}
seconds
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
seconds
<=
0
{
return
0
,
false
}
return
time
.
Duration
(
seconds
)
*
time
.
Second
,
true
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
...
...
@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
}
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
if
fallbackDur
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
fallbackDur
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
...
...
@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
func
(
s
*
AntigravityGatewayService
)
WriteMappedClaudeError
(
c
*
gin
.
Context
,
account
*
Account
,
upstreamStatus
int
,
upstreamRequestID
string
,
body
[]
byte
)
error
{
return
s
.
writeMappedClaudeError
(
c
,
account
,
upstreamStatus
,
upstreamRequestID
,
body
)
}
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
statusStr
:=
"UNKNOWN"
switch
status
{
...
...
@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return
json
.
Marshal
(
payload
)
}
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
func
filterEmptyPartsFromGeminiRequest
(
body
[]
byte
)
([]
byte
,
error
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
nil
,
err
}
contents
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contents
)
==
0
{
return
body
,
nil
}
filtered
:=
make
([]
any
,
0
,
len
(
contents
))
modified
:=
false
for
_
,
c
:=
range
contents
{
contentMap
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
parts
,
hasParts
:=
contentMap
[
"parts"
]
if
!
hasParts
{
filtered
=
append
(
filtered
,
c
)
continue
}
partsSlice
,
ok
:=
parts
.
([]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
// 跳过 parts 为空数组的消息
if
len
(
partsSlice
)
==
0
{
modified
=
true
continue
}
filtered
=
append
(
filtered
,
c
)
}
if
!
modified
{
return
body
,
nil
}
payload
[
"contents"
]
=
filtered
return
json
.
Marshal
(
payload
)
}
backend/internal/service/antigravity_gateway_service_test.go
View file @
dd96ada3
package
service
import
(
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
...
...
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
}
func
TestIsPromptTooLongError
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)))
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"message":"Prompt is too long"}`
)))
require
.
False
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"other"}}`
)))
}
type
httpUpstreamStub
struct
{
resp
*
http
.
Response
err
error
}
func
(
s
*
httpUpstreamStub
)
Do
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
(
s
*
httpUpstreamStub
)
DoWithTLS
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
,
_
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
body
,
err
:=
json
.
Marshal
(
map
[
string
]
any
{
"model"
:
"claude-opus-4-5"
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"hi"
},
},
"max_tokens"
:
1
,
"stream"
:
false
,
})
require
.
NoError
(
t
,
err
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
=
req
respBody
:=
[]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusBadRequest
,
Header
:
http
.
Header
{
"X-Request-Id"
:
[]
string
{
"req-1"
}},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
svc
:=
&
AntigravityGatewayService
{
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"acc-1"
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Nil
(
t
,
result
)
var
promptErr
*
PromptTooLongError
require
.
ErrorAs
(
t
,
err
,
&
promptErr
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
promptErr
.
StatusCode
)
require
.
Equal
(
t
,
"req-1"
,
promptErr
.
RequestID
)
require
.
NotEmpty
(
t
,
promptErr
.
Body
)
raw
,
ok
:=
c
.
Get
(
OpsUpstreamErrorsKey
)
require
.
True
(
t
,
ok
)
events
,
ok
:=
raw
.
([]
*
OpsUpstreamErrorEvent
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
events
,
1
)
require
.
Equal
(
t
,
"prompt_too_long"
,
events
[
0
]
.
Kind
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitch
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"4"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
"7"
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
false
)
require
.
Equal
(
t
,
4
,
got
)
got
=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
true
)
require
.
Equal
(
t
,
7
,
got
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitchFallback
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"5"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"gemini-2.5-flash"
,
true
)
require
.
Equal
(
t
,
5
,
got
)
}
backend/internal/service/antigravity_quota_scope.go
View file @
dd96ada3
package
service
import
(
"slices"
"strings"
"time"
)
...
...
@@ -16,6 +17,21 @@ const (
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func
IsScopeSupported
(
supportedScopes
[]
string
,
scope
AntigravityQuotaScope
)
bool
{
if
len
(
supportedScopes
)
==
0
{
// 未配置时默认全部支持
return
true
}
supported
:=
slices
.
Contains
(
supportedScopes
,
string
(
scope
))
return
supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func
ResolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
return
resolveAntigravityQuotaScope
(
requestedModel
)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
...
...
backend/internal/service/api_key.go
View file @
dd96ada3
...
...
@@ -2,6 +2,14 @@ package service
import
"time"
// API Key status constants
const
(
StatusAPIKeyActive
=
"active"
StatusAPIKeyDisabled
=
"disabled"
StatusAPIKeyQuotaExhausted
=
"quota_exhausted"
StatusAPIKeyExpired
=
"expired"
)
type
APIKey
struct
{
ID
int64
UserID
int64
...
...
@@ -15,8 +23,53 @@ type APIKey struct {
UpdatedAt
time
.
Time
User
*
User
Group
*
Group
// Quota fields
Quota
float64
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
// Used quota amount
ExpiresAt
*
time
.
Time
// Expiration time (nil = never expires)
}
func
(
k
*
APIKey
)
IsActive
()
bool
{
return
k
.
Status
==
StatusActive
}
// IsExpired checks if the API key has expired
func
(
k
*
APIKey
)
IsExpired
()
bool
{
if
k
.
ExpiresAt
==
nil
{
return
false
}
return
time
.
Now
()
.
After
(
*
k
.
ExpiresAt
)
}
// IsQuotaExhausted checks if the API key quota is exhausted
func
(
k
*
APIKey
)
IsQuotaExhausted
()
bool
{
if
k
.
Quota
<=
0
{
return
false
// unlimited
}
return
k
.
QuotaUsed
>=
k
.
Quota
}
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
func
(
k
*
APIKey
)
GetQuotaRemaining
()
float64
{
if
k
.
Quota
<=
0
{
return
-
1
// unlimited
}
remaining
:=
k
.
Quota
-
k
.
QuotaUsed
if
remaining
<
0
{
return
0
}
return
remaining
}
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
func
(
k
*
APIKey
)
GetDaysUntilExpiry
()
int
{
if
k
.
ExpiresAt
==
nil
{
return
-
1
// never expires
}
duration
:=
time
.
Until
(
*
k
.
ExpiresAt
)
if
duration
<
0
{
return
0
}
return
int
(
duration
.
Hours
()
/
24
)
}
backend/internal/service/api_key_auth_cache.go
View file @
dd96ada3
package
service
import
"time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type
APIKeyAuthSnapshot
struct
{
APIKeyID
int64
`json:"api_key_id"`
...
...
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
IPBlacklist
[]
string
`json:"ip_blacklist,omitempty"`
User
APIKeyAuthUserSnapshot
`json:"user"`
Group
*
APIKeyAuthGroupSnapshot
`json:"group,omitempty"`
// Quota fields for API Key independent quota feature
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount
// Expiration field for API Key expiration feature
ExpiresAt
*
time
.
Time
`json:"expires_at,omitempty"`
// Expiration time (nil = never expires)
}
// APIKeyAuthUserSnapshot 用户快照
...
...
@@ -23,25 +32,30 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照
type
APIKeyAuthGroupSnapshot
struct
{
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
MCPXMLInject
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
dd96ada3
...
...
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Status
:
apiKey
.
Status
,
IPWhitelist
:
apiKey
.
IPWhitelist
,
IPBlacklist
:
apiKey
.
IPBlacklist
,
Quota
:
apiKey
.
Quota
,
QuotaUsed
:
apiKey
.
QuotaUsed
,
ExpiresAt
:
apiKey
.
ExpiresAt
,
User
:
APIKeyAuthUserSnapshot
{
ID
:
apiKey
.
User
.
ID
,
Status
:
apiKey
.
User
.
Status
,
...
...
@@ -223,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
}
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
apiKey
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
apiKey
.
Group
.
SupportedModelScopes
,
}
}
return
snapshot
...
...
@@ -256,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Status
:
snapshot
.
Status
,
IPWhitelist
:
snapshot
.
IPWhitelist
,
IPBlacklist
:
snapshot
.
IPBlacklist
,
Quota
:
snapshot
.
Quota
,
QuotaUsed
:
snapshot
.
QuotaUsed
,
ExpiresAt
:
snapshot
.
ExpiresAt
,
User
:
&
User
{
ID
:
snapshot
.
User
.
ID
,
Status
:
snapshot
.
User
.
Status
,
...
...
@@ -266,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
}
if
snapshot
.
Group
!=
nil
{
apiKey
.
Group
=
&
Group
{
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
snapshot
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
snapshot
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
}
}
return
apiKey
...
...
backend/internal/service/api_key_service.go
View file @
dd96ada3
...
...
@@ -24,6 +24,10 @@ var (
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
ErrAPIKeyExpired
=
infraerrors
.
Forbidden
(
"API_KEY_EXPIRED"
,
"api key 已过期"
)
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted
=
infraerrors
.
TooManyRequests
(
"API_KEY_QUOTA_EXHAUSTED"
,
"api key 额度已用完"
)
)
const
(
...
...
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
// Quota methods
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
}
// APIKeyCache defines cache operations for API key service
...
...
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
// Quota fields
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
ExpiresInDays
*
int
`json:"expires_in_days"`
// Days until expiry (nil = never expires)
}
// UpdateAPIKeyRequest 更新API Key请求
...
...
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
Status
*
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
// Quota fields
Quota
*
float64
`json:"quota"`
// Quota limit in USD (nil = no change, 0 = unlimited)
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = no change)
ClearExpiration
bool
`json:"-"`
// Clear expiration (internal use)
ResetQuota
*
bool
`json:"reset_quota"`
// Reset quota_used to 0
}
// APIKeyService API Key服务
...
...
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
Status
:
StatusActive
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
QuotaUsed
:
0
,
}
// Set expiration time if specified
if
req
.
ExpiresInDays
!=
nil
&&
*
req
.
ExpiresInDays
>
0
{
expiresAt
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
*
req
.
ExpiresInDays
)
apiKey
.
ExpiresAt
=
&
expiresAt
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
...
...
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// Update quota fields
if
req
.
Quota
!=
nil
{
apiKey
.
Quota
=
*
req
.
Quota
// If quota is increased and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
&&
*
req
.
Quota
>
apiKey
.
QuotaUsed
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ResetQuota
!=
nil
&&
*
req
.
ResetQuota
{
apiKey
.
QuotaUsed
=
0
// If resetting quota and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ClearExpiration
{
apiKey
.
ExpiresAt
=
nil
// If clearing expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
{
apiKey
.
Status
=
StatusActive
}
}
else
if
req
.
ExpiresAt
!=
nil
{
apiKey
.
ExpiresAt
=
req
.
ExpiresAt
// If extending expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
&&
time
.
Now
()
.
Before
(
*
req
.
ExpiresAt
)
{
apiKey
.
Status
=
StatusActive
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
...
...
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
}
return
keys
,
nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func
(
s
*
APIKeyService
)
CheckAPIKeyQuotaAndExpiry
(
apiKey
*
APIKey
)
error
{
// Check expiration
if
apiKey
.
IsExpired
()
{
return
ErrAPIKeyExpired
}
// Check quota
if
apiKey
.
IsQuotaExhausted
()
{
return
ErrAPIKeyQuotaExhausted
}
return
nil
}
// UpdateQuotaUsed updates the quota_used field after a request
// Also checks if quota is exhausted and updates status accordingly
func
(
s
*
APIKeyService
)
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
if
cost
<=
0
{
return
nil
}
// Use repository to atomically increment quota_used
newQuotaUsed
,
err
:=
s
.
apiKeyRepo
.
IncrementQuotaUsed
(
ctx
,
apiKeyID
,
cost
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"increment quota used: %w"
,
err
)
}
// Check if quota is now exhausted and update status if needed
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
apiKeyID
)
if
err
!=
nil
{
return
nil
// Don't fail the request, just log
}
// If quota is set and now exhausted, update status
if
apiKey
.
Quota
>
0
&&
newQuotaUsed
>=
apiKey
.
Quota
{
apiKey
.
Status
=
StatusAPIKeyQuotaExhausted
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
// Don't fail the request
}
// Invalidate cache so next request sees the new status
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
}
return
nil
}
backend/internal/service/api_key_service_cache_test.go
View file @
dd96ada3
...
...
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
return
s
.
listKeysByGroupID
(
ctx
,
groupID
)
}
func
(
s
*
authRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
type
authCacheStub
struct
{
getAuthCache
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
setAuthKeys
[]
string
...
...
backend/internal/service/api_key_service_delete_test.go
View file @
dd96ada3
...
...
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
panic
(
"unexpected ListKeysByGroupID call"
)
}
func
(
s
*
apiKeyRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
...
...
backend/internal/service/auth_service.go
View file @
dd96ada3
...
...
@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
backend/internal/service/domain_constants.go
View file @
dd96ada3
...
...
@@ -31,6 +31,7 @@ const (
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
)
// Redeem type constants
...
...
backend/internal/service/gateway_service.go
View file @
dd96ada3
...
...
@@ -257,6 +257,9 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var
ErrClaudeCodeOnly
=
errors
.
New
(
"this group only allows Claude Code clients"
)
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var
ErrModelScopeNotSupported
=
errors
.
New
(
"model scope not supported by this group"
)
// allowedHeaders 白名单headers(参考CRS项目)
var
allowedHeaders
=
map
[
string
]
bool
{
"accept"
:
true
,
...
...
@@ -585,12 +588,18 @@ func (s *GatewayService) hashContent(content string) string {
}
// replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
func
(
s
*
GatewayService
)
replaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
var
req
map
[
string
]
any
var
req
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
}
req
[
"model"
]
=
newModel
// 只序列化 model 字段
modelBytes
,
err
:=
json
.
Marshal
(
newModel
)
if
err
!=
nil
{
return
body
}
req
[
"model"
]
=
modelBytes
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
...
...
@@ -787,12 +796,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var
reqRaw
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqRaw
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
modified
:=
false
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
...
...
@@ -800,6 +818,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
modified
=
true
}
case
[]
any
:
for
_
,
item
:=
range
v
{
...
...
@@ -817,6 +836,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
modified
=
true
}
}
}
...
...
@@ -827,6 +847,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
modified
=
true
}
}
...
...
@@ -842,16 +863,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
modified
=
true
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
modified
=
true
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
modified
=
true
}
tools
[
idx
]
=
toolMap
}
...
...
@@ -880,11 +904,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
modified
=
true
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
modified
=
true
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified
:=
false
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
...
...
@@ -895,6 +923,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
!
ok
{
continue
}
// 检查此消息是否包含 thinking 块
hasThinking
:=
false
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
if
blockType
==
"thinking"
||
blockType
==
"redacted_thinking"
{
hasThinking
=
true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if
hasThinking
{
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
...
...
@@ -907,6 +953,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
messagesModified
=
true
}
}
}
...
...
@@ -916,6 +963,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
modified
=
true
}
}
...
...
@@ -927,12 +975,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
modified
=
true
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
if
_
,
hasTemp
:=
req
[
"temperature"
];
hasTemp
{
delete
(
req
,
"temperature"
)
modified
=
true
}
if
_
,
hasChoice
:=
req
[
"tool_choice"
];
hasChoice
{
delete
(
req
,
"tool_choice"
)
modified
=
true
}
if
!
modified
&&
!
messagesModified
{
return
body
,
modelID
,
toolNameMap
}
// 如果 messages 没有被修改,保留原始 messages 字节
if
!
messagesModified
{
// 序列化非 messages 字段
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
// 替换回原始的 messages
var
newReq
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
newReq
);
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
if
origMessages
,
ok
:=
reqRaw
[
"messages"
];
ok
{
newReq
[
"messages"
]
=
origMessages
}
finalBody
,
err
:=
json
.
Marshal
(
newReq
)
if
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
return
finalBody
,
modelID
,
toolNameMap
}
// messages 被修改了,需要完整序列化
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
...
...
@@ -1135,6 +1217,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -1632,6 +1721,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return
group
,
nil
}
func
(
s
*
GatewayService
)
ResolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
return
s
.
resolveGroupByID
(
ctx
,
groupID
)
}
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
return
nil
...
...
@@ -1697,7 +1790,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
if
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
&&
forcePlatform
!=
""
{
return
nil
,
groupID
,
nil
}
...
...
@@ -2026,6 +2119,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
...
...
@@ -2461,6 +2561,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
return
true
...
...
@@ -2910,16 +3014,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:
=
account
.
GetMappedModel
(
reqModel
)
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s)"
,
originalModel
,
mappedModel
,
account
.
Name
)
mappingSource
=
"account"
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s, source=%s)"
,
originalModel
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
// 获取凭证
token
,
tokenType
,
err
:=
s
.
GetAccessToken
(
ctx
,
account
)
...
...
@@ -3621,6 +3739,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return
true
}
// 检测 thinking block 被修改的错误
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
log
.
Printf
(
"[SignatureCheck] Detected thinking block modification error"
)
return
true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if
strings
.
Contains
(
msg
,
"non-empty content"
)
||
strings
.
Contains
(
msg
,
"empty content"
)
{
...
...
@@ -4489,13 +4614,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
APIKeyService
APIKeyQuotaUpdater
// 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
type
APIKeyQuotaUpdater
interface
{
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
...
...
@@ -4635,6 +4766,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
}
// 更新 API Key 配额(如果设置了配额限制)
if
shouldBill
&&
cost
.
ActualCost
>
0
&&
apiKey
.
Quota
>
0
&&
input
.
APIKeyService
!=
nil
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Update API key quota failed: %v"
,
err
)
}
}
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
...
...
@@ -4652,6 +4790,7 @@ type RecordUsageLongContextInput struct {
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
APIKeyService
*
APIKeyService
// API Key 配额服务(可选)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...
...
@@ -4788,6 +4927,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
// API Key 独立配额扣费
if
input
.
APIKeyService
!=
nil
&&
apiKey
.
Quota
>
0
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Add API key quota used failed: %v"
,
err
)
}
}
}
}
...
...
@@ -4822,16 +4967,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return
nil
}
// 应用模型映射(仅对 apikey 类型账号)
if
account
.
Type
==
AccountTypeAPIKey
{
if
reqModel
!=
""
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
if
reqModel
!=
""
{
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
)
mappingSource
=
"account"
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s, source=%s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
}
// 获取凭证
...
...
@@ -5083,6 +5242,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return
normalized
,
nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func
(
s
*
GatewayService
)
checkAntigravityModelScope
(
ctx
context
.
Context
,
groupID
int64
,
requestedModel
string
)
error
{
scope
,
ok
:=
ResolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
nil
// 无法解析 scope,跳过检查
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
// 查询失败时放行
}
if
group
==
nil
{
return
nil
// 分组不存在时放行
}
if
!
IsScopeSupported
(
group
.
SupportedModelScopes
,
scope
)
{
return
ErrModelScopeNotSupported
}
return
nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
dd96ada3
...
...
@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
if
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
);
err
==
nil
{
body
=
filteredBody
}
switch
action
{
case
"generateContent"
,
"streamGenerateContent"
,
"countTokens"
:
// ok
...
...
backend/internal/service/gemini_native_signature_cleaner.go
View file @
dd96ada3
...
...
@@ -2,20 +2,22 @@ package service
import
(
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
移除
thoughtSignature 字段,
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
替换
thoughtSignature 字段
为 dummy 签名
,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过
移除这些签名,让新账号重新生成有效的签名
。
// 会导致新账号的签名验证失败。通过
替换为 dummy 签名,跳过签名验证
。
//
// CleanGeminiNativeThoughtSignatures re
mov
es thoughtSignature fields
from Gemini native API requests
// to avoid cross-account signature validation errors.
// CleanGeminiNativeThoughtSignatures re
plac
es thoughtSignature fields
with dummy signature
//
in Gemini native API requests
to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By re
moving these
signature
s
, we
allow the new account to generate valid signatures
.
// By re
placing with dummy
signature, we
skip signature validation
.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
...
...
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
body
}
// 递归
清理
thoughtSignature
cleaned
:=
clean
ThoughtSignaturesRecursive
(
data
)
// 递归
替换
thoughtSignature
为 dummy 签名
replaced
:=
replace
ThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
clean
ed
)
result
,
err
:=
json
.
Marshal
(
replac
ed
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
...
...
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
result
}
//
clean
ThoughtSignaturesRecursive 递归遍历数据结构,
移除
所有 thoughtSignature 字段
func
clean
ThoughtSignaturesRecursive
(
data
any
)
any
{
//
replace
ThoughtSignaturesRecursive 递归遍历数据结构,
将
所有 thoughtSignature 字段
替换为 dummy 签名
func
replace
ThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,
移除
thoughtSignature
// 创建新的 map,
替换
thoughtSignature
为 dummy 签名
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
//
跳过
thoughtSignature 字段
//
替换
thoughtSignature 字段
为 dummy 签名
if
key
==
"thoughtSignature"
{
result
[
key
]
=
antigravity
.
DummyThoughtSignature
continue
}
// 递归处理嵌套结构
result
[
key
]
=
clean
ThoughtSignaturesRecursive
(
value
)
result
[
key
]
=
replace
ThoughtSignaturesRecursive
(
value
)
}
return
result
...
...
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
clean
ThoughtSignaturesRecursive
(
item
)
result
[
i
]
=
replace
ThoughtSignaturesRecursive
(
item
)
}
return
result
...
...
backend/internal/service/group.go
View file @
dd96ada3
...
...
@@ -29,6 +29,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
// 无效请求兜底分组(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
...
...
@@ -36,6 +38,13 @@ type Group struct {
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes
[]
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
backend/internal/service/identity_service.go
View file @
dd96ada3
...
...
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
// RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserID
(
body
[]
byte
,
accountID
int64
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
if
len
(
body
)
==
0
||
accountUUID
==
""
||
cachedClientID
==
""
{
return
body
,
nil
}
//
解析JSON
var
reqMap
map
[
string
]
any
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqMap
);
err
!=
nil
{
return
body
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
return
body
,
nil
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
body
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
body
,
nil
...
...
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
newUserID
:=
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
cachedClientID
,
accountUUID
,
newSessionHash
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
err
:=
json
.
Marshal
(
metadata
)
if
err
!=
nil
{
return
body
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
}
...
...
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
// 先执行常规的 RewriteUserID 逻辑
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
...
...
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return
newBody
,
nil
}
//
解析重写后的 body,提取 user_id
var
reqMap
map
[
string
]
any
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
return
newBody
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
return
newBody
,
nil
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
newBody
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
newBody
,
nil
...
...
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
marshalErr
:=
json
.
Marshal
(
metadata
)
if
marshalErr
!=
nil
{
return
newBody
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
}
...
...
backend/internal/service/openai_codex_transform.go
View file @
dd96ada3
...
...
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
LastChecked
int64
`json:"lastChecked"`
}
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
)
codexTransformResult
{
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
codexTransformResult
{
result
:=
codexTransformResult
{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation
:=
NeedsToolContinuation
(
reqBody
)
...
...
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result
.
PromptCacheKey
=
strings
.
TrimSpace
(
v
)
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
result
.
Modified
=
true
}
}
else
if
existingInstructions
==
""
{
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
result
.
Modified
=
true
}
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
if
applyInstructions
(
reqBody
,
isCodexCLI
)
{
result
.
Modified
=
true
}
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
...
...
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
return
getCodexCLIInstructions
()
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
func
applyInstructions
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
bool
{
if
isCodexCLI
{
return
applyCodexCLIInstructions
(
reqBody
)
}
return
applyOpenCodeInstructions
(
reqBody
)
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
func
applyCodexCLIInstructions
(
reqBody
map
[
string
]
any
)
bool
{
if
!
isInstructionsEmpty
(
reqBody
)
{
return
false
// 已有有效 instructions,不修改
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
if
instructions
!=
""
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
return
false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
func
applyOpenCodeInstructions
(
reqBody
map
[
string
]
any
)
bool
{
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
}
else
if
existingInstructions
==
""
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
return
true
}
}
return
false
}
// isInstructionsEmpty 检查 instructions 字段是否为空
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
func
isInstructionsEmpty
(
reqBody
map
[
string
]
any
)
bool
{
val
,
exists
:=
reqBody
[
"instructions"
]
if
!
exists
{
return
true
}
if
val
==
nil
{
return
true
}
str
,
ok
:=
val
.
(
string
)
if
!
ok
{
return
true
}
return
strings
.
TrimSpace
(
str
)
==
""
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func
ReplaceWithCodexInstructions
(
reqBody
map
[
string
]
any
)
bool
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment