Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
de7ff902
Commit
de7ff902
authored
Feb 04, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test
parents
317f26f0
dd96ada3
Changes
90
Hide whitespace changes
Inline
Side-by-side
backend/internal/server/middleware/api_key_auth_test.go
View file @
de7ff902
...
...
@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
...
...
backend/internal/service/admin_service.go
View file @
de7ff902
...
...
@@ -116,9 +116,14 @@ type CreateGroupInput struct {
SoraVideoPricePerRequestHD
*
float64
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -145,9 +150,14 @@ type UpdateGroupInput struct {
SoraVideoPricePerRequestHD
*
float64
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -611,6 +621,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
nil
,
err
}
}
fallbackOnInvalidRequest
:=
input
.
FallbackGroupIDOnInvalidRequest
if
fallbackOnInvalidRequest
!=
nil
&&
*
fallbackOnInvalidRequest
<=
0
{
fallbackOnInvalidRequest
=
nil
}
// 校验无效请求兜底分组
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
0
,
platform
,
subscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
mcpXMLInject
:=
true
if
input
.
MCPXMLInject
!=
nil
{
mcpXMLInject
=
*
input
.
MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var
accountIDsToCopy
[]
int64
...
...
@@ -645,26 +671,29 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
group
:=
&
Group
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
dailyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice4K
:
imagePrice4K
,
SoraImagePrice360
:
soraImagePrice360
,
SoraImagePrice540
:
soraImagePrice540
,
SoraVideoPricePerRequest
:
soraVideoPrice
,
SoraVideoPricePerRequestHD
:
soraVideoPriceHD
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
ModelRouting
:
input
.
ModelRouting
,
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
dailyLimit
,
WeeklyLimitUSD
:
weeklyLimit
,
MonthlyLimitUSD
:
monthlyLimit
,
ImagePrice1K
:
imagePrice1K
,
ImagePrice2K
:
imagePrice2K
,
ImagePrice4K
:
imagePrice4K
,
SoraImagePrice360
:
soraImagePrice360
,
SoraImagePrice540
:
soraImagePrice540
,
SoraVideoPricePerRequest
:
soraVideoPrice
,
SoraVideoPricePerRequestHD
:
soraVideoPriceHD
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
fallbackOnInvalidRequest
,
ModelRouting
:
input
.
ModelRouting
,
MCPXMLInject
:
mcpXMLInject
,
SupportedModelScopes
:
input
.
SupportedModelScopes
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -735,6 +764,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
}
}
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func
(
s
*
adminServiceImpl
)
validateFallbackGroupOnInvalidRequest
(
ctx
context
.
Context
,
currentGroupID
int64
,
platform
,
subscriptionType
string
,
fallbackGroupID
int64
)
error
{
if
platform
!=
PlatformAnthropic
&&
platform
!=
PlatformAntigravity
{
return
fmt
.
Errorf
(
"invalid request fallback only supported for anthropic or antigravity groups"
)
}
if
subscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"subscription groups cannot set invalid request fallback"
)
}
if
currentGroupID
>
0
&&
currentGroupID
==
fallbackGroupID
{
return
fmt
.
Errorf
(
"cannot set self as invalid request fallback group"
)
}
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
fallbackGroupID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
if
fallbackGroup
.
Platform
!=
PlatformAnthropic
{
return
fmt
.
Errorf
(
"fallback group must be anthropic platform"
)
}
if
fallbackGroup
.
SubscriptionType
==
SubscriptionTypeSubscription
{
return
fmt
.
Errorf
(
"fallback group cannot be subscription type"
)
}
if
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
return
fmt
.
Errorf
(
"fallback group cannot have invalid request fallback configured"
)
}
return
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
...
...
@@ -813,6 +873,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group
.
FallbackGroupID
=
nil
}
}
fallbackOnInvalidRequest
:=
group
.
FallbackGroupIDOnInvalidRequest
if
input
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
if
*
input
.
FallbackGroupIDOnInvalidRequest
>
0
{
fallbackOnInvalidRequest
=
input
.
FallbackGroupIDOnInvalidRequest
}
else
{
fallbackOnInvalidRequest
=
nil
}
}
if
fallbackOnInvalidRequest
!=
nil
{
if
err
:=
s
.
validateFallbackGroupOnInvalidRequest
(
ctx
,
id
,
group
.
Platform
,
group
.
SubscriptionType
,
*
fallbackOnInvalidRequest
);
err
!=
nil
{
return
nil
,
err
}
}
group
.
FallbackGroupIDOnInvalidRequest
=
fallbackOnInvalidRequest
// 模型路由配置
if
input
.
ModelRouting
!=
nil
{
...
...
@@ -821,6 +895,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
ModelRoutingEnabled
!=
nil
{
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
}
if
input
.
MCPXMLInject
!=
nil
{
group
.
MCPXMLInject
=
*
input
.
MCPXMLInject
}
// 支持的模型系列(仅 antigravity 平台使用)
if
input
.
SupportedModelScopes
!=
nil
{
group
.
SupportedModelScopes
=
*
input
.
SupportedModelScopes
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
backend/internal/service/admin_service_group_test.go
View file @
de7ff902
...
...
@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
created
*
Group
updated
*
Group
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Create
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
created
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Update
(
_
context
.
Context
,
g
*
Group
)
error
{
s
.
updated
=
g
return
nil
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
GetAccountIDsByGroupIDs
(
_
context
.
Context
,
_
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
BindAccountsToGroup
(
_
context
.
Context
,
_
int64
,
_
[]
int64
)
error
{
panic
(
"unexpected BindAccountsToGroup call"
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
fallback
*
Group
wantMessage
string
}{
{
name
:
"openai_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"antigravity_target"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
},
wantMessage
:
"fallback group must be anthropic platform"
,
},
{
name
:
"subscription_group"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
wantMessage
:
"fallback group cannot be subscription type"
,
},
{
name
:
"nested_fallback"
,
fallback
:
&
Group
{
ID
:
10
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
func
()
*
int64
{
v
:=
int64
(
99
);
return
&
v
}(),
},
wantMessage
:
"fallback group cannot have invalid request fallback configured"
,
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
fallbackID
:=
tc
.
fallback
.
ID
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
tc
.
fallback
,
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
tc
.
wantMessage
)
require
.
Nil
(
t
,
repo
.
created
)
})
}
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackNotFound
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group not found"
)
require
.
Nil
(
t
,
repo
.
created
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
zero
:=
int64
(
0
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
CreateGroup
(
context
.
Background
(),
&
CreateGroupInput
{
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
FallbackGroupIDOnInvalidRequest
:
&
zero
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
created
)
require
.
Nil
(
t
,
repo
.
created
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid request fallback only supported for anthropic or antigravity groups"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
SubscriptionType
:
SubscriptionTypeSubscription
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"subscription groups cannot set invalid request fallback"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
clear
:=
int64
(
0
)
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
Platform
:
PlatformOpenAI
,
FallbackGroupIDOnInvalidRequest
:
&
clear
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Nil
(
t
,
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeSubscription
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
_
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cannot be subscription type"
)
require
.
Nil
(
t
,
repo
.
updated
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
func
TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
existing
:=
&
Group
{
ID
:
1
,
Name
:
"g1"
,
Platform
:
PlatformAntigravity
,
SubscriptionType
:
SubscriptionTypeStandard
,
Status
:
StatusActive
,
}
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
groups
:
map
[
int64
]
*
Group
{
existing
.
ID
:
existing
,
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
SubscriptionType
:
SubscriptionTypeStandard
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
group
,
err
:=
svc
.
UpdateGroup
(
context
.
Background
(),
existing
.
ID
,
&
UpdateGroupInput
{
FallbackGroupIDOnInvalidRequest
:
&
fallbackID
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
group
)
require
.
NotNil
(
t
,
repo
.
updated
)
require
.
Equal
(
t
,
fallbackID
,
*
repo
.
updated
.
FallbackGroupIDOnInvalidRequest
)
}
backend/internal/service/antigravity_gateway_service.go
View file @
de7ff902
...
...
@@ -13,23 +13,34 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const
(
antigravityStickySessionTTL
=
time
.
Hour
antigravityMaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
antigravityStickySessionTTL
=
time
.
Hour
antigravity
Default
MaxRetries
=
3
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
const
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
const
(
antigravityMaxRetriesEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv
=
"GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv
=
"GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv
=
"GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv
=
"GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数
type
antigravityRetryLoopParams
struct
{
...
...
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
action
string
body
[]
byte
quotaScope
AntigravityQuotaScope
maxRetries
int
c
*
gin
.
Context
httpUpstream
HTTPUpstream
settingService
*
SettingService
...
...
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
resp
*
http
.
Response
}
// PromptTooLongError 表示上游明确返回 prompt too long
type
PromptTooLongError
struct
{
StatusCode
int
RequestID
string
Body
[]
byte
}
func
(
e
*
PromptTooLongError
)
Error
()
string
{
return
fmt
.
Sprintf
(
"prompt too long: status=%d"
,
e
.
StatusCode
)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func
antigravityRetryLoop
(
p
antigravityRetryLoopParams
)
(
*
antigravityRetryLoopResult
,
error
)
{
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
baseURLs
:=
antigravity
.
ForwardBaseURLs
()
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLsWithBase
(
baseURLs
)
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
availableURLs
=
baseURLs
}
maxRetries
:=
p
.
maxRetries
if
maxRetries
<=
0
{
maxRetries
=
antigravityDefaultMaxRetries
}
var
resp
*
http
.
Response
...
...
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
usedBaseURL
=
baseURL
for
attempt
:=
1
;
attempt
<=
antigravityM
axRetries
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
m
axRetries
;
attempt
++
{
select
{
case
<-
p
.
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
p
.
prefix
,
p
.
ctx
.
Err
())
...
...
@@ -109,8 +138,8 @@ urlFallbackLoop:
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
urlFallbackLoop
}
if
attempt
<
antigravityM
axRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
err
)
if
attempt
<
m
axRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
p
.
prefix
,
attempt
,
m
axRetries
,
err
)
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -134,7 +163,7 @@ urlFallbackLoop:
}
// 账户/模型配额限流,重试 3 次(指数退避)
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -147,7 +176,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
200
))
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
200
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -171,7 +200,7 @@ urlFallbackLoop:
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityM
axRetries
{
if
attempt
<
m
axRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -184,7 +213,7 @@ urlFallbackLoop:
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityM
axRetries
,
truncateForLog
(
respBody
,
500
))
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
m
axRetries
,
truncateForLog
(
respBody
,
500
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
return
nil
,
p
.
ctx
.
Err
()
...
...
@@ -390,6 +419,11 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func
(
s
*
AntigravityGatewayService
)
TestConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
// 上游透传账号使用专用测试方法
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
testUpstreamConnection
(
ctx
,
account
,
modelID
)
}
// 获取 token
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
...
...
@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return
nil
,
lastErr
}
// testUpstreamConnection 测试上游透传账号连接
func
(
s
*
AntigravityGatewayService
)
testUpstreamConnection
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
)
(
*
TestConnectionResult
,
error
)
{
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
errors
.
New
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 使用 Claude 模型进行测试
if
modelID
==
""
{
modelID
=
"claude-sonnet-4-20250514"
}
// 构建最小测试请求
testReq
:=
map
[
string
]
any
{
"model"
:
modelID
,
"max_tokens"
:
1
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"."
},
},
}
requestBody
,
err
:=
json
.
Marshal
(
testReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"构建请求失败: %w"
,
err
)
}
// 构建 HTTP 请求
upstreamURL
:=
baseURL
+
"/v1/messages"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
requestBody
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
log
.
Printf
(
"[antigravity-Test-Upstream] account=%s url=%s"
,
account
.
Name
,
upstreamURL
)
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
respBody
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
if
resp
.
StatusCode
>=
400
{
return
nil
,
fmt
.
Errorf
(
"API 返回 %d: %s"
,
resp
.
StatusCode
,
string
(
respBody
))
}
// 提取响应文本
var
respData
map
[
string
]
any
text
:=
""
if
json
.
Unmarshal
(
respBody
,
&
respData
)
==
nil
{
if
content
,
ok
:=
respData
[
"content"
]
.
([]
any
);
ok
&&
len
(
content
)
>
0
{
if
block
,
ok
:=
content
[
0
]
.
(
map
[
string
]
any
);
ok
{
if
t
,
ok
:=
block
[
"text"
]
.
(
string
);
ok
{
text
=
t
}
}
}
}
return
&
TestConnectionResult
{
Text
:
text
,
MappedModel
:
modelID
,
},
nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func
(
s
*
AntigravityGatewayService
)
buildGeminiTestRequest
(
projectID
,
model
string
)
([]
byte
,
error
)
{
...
...
@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
}
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
group
!=
nil
{
opts
.
EnableMCPXML
=
group
.
MCPXMLInject
}
return
opts
}
...
...
@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
// 上游透传账号直接转发,不走 OAuth token 刷新
if
account
.
Type
==
AccountTypeUpstream
{
return
s
.
ForwardUpstream
(
ctx
,
c
,
account
,
body
)
}
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
...
...
@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
...
...
@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
retryErr
!=
nil
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
...
...
@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试)
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
==
http
.
StatusBadRequest
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
log
.
Printf
(
"%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s"
,
prefix
,
isPromptTooLongError
(
respBody
),
upstreamMsg
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateForLog
(
respBody
,
500
))
}
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isPromptTooLongError
(
respBody
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"prompt_too_long"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
PromptTooLongError
{
StatusCode
:
resp
.
StatusCode
,
RequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Body
:
respBody
,
}
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
...
...
@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
//
使用原始模型用于计费和日志
Model
:
billing
Model
,
//
计费模型(可按映射模型覆盖)
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
return
true
}
// Detect thinking block modification errors:
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
return
true
}
return
false
}
func
isPromptTooLongError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
)))
if
msg
==
""
{
msg
=
strings
.
ToLower
(
string
(
respBody
))
}
return
strings
.
Contains
(
msg
,
"prompt is too long"
)
}
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
""
}
parseNestedMessage
:=
func
(
msg
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
msg
)
if
trimmed
==
""
||
!
strings
.
HasPrefix
(
trimmed
,
"{"
)
{
return
""
}
var
nested
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
trimmed
),
&
nested
);
err
!=
nil
{
return
""
}
if
errObj
,
ok
:=
nested
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
innerMsg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
}
if
innerMsg
,
ok
:=
nested
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
innerMsg
)
!=
""
{
return
innerMsg
}
return
""
}
// Google-style: {"error": {"message": "..."}}
if
errObj
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errObj
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
}
}
// Fallback: top-level message
if
msg
,
ok
:=
payload
[
"message"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
msg
)
!=
""
{
if
innerMsg
:=
parseNestedMessage
(
msg
);
innerMsg
!=
""
{
return
innerMsg
}
return
msg
}
...
...
@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return
changed
,
nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func
(
s
*
AntigravityGatewayService
)
ForwardUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
// 获取上游配置
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 解析请求获取模型信息
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
// 创建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
// 设置请求头
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
// Claude API 兼容
// 透传 Claude 相关 headers
if
v
:=
c
.
GetHeader
(
"anthropic-version"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
v
)
}
if
v
:=
c
.
GetHeader
(
"anthropic-beta"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
v
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
log
.
Printf
(
"%s upstream request failed: %v"
,
prefix
,
err
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 429 错误时标记账号限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
AntigravityQuotaScopeClaude
)
}
// 透传上游错误
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
resp
.
StatusCode
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billingModel
,
},
nil
}
// 处理成功响应(流式/非流式)
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
claudeReq
.
Stream
{
// 流式响应:透传
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
usage
,
firstTokenMs
=
s
.
streamUpstreamResponse
(
c
,
resp
,
startTime
)
}
else
{
// 非流式响应:直接透传
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read upstream response: %w"
,
err
)
}
// 提取 usage
usage
=
s
.
extractClaudeUsage
(
respBody
)
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
http
.
StatusOK
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
}
// 构建计费结果
duration
:=
time
.
Since
(
startTime
)
log
.
Printf
(
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billingModel
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{
InputTokens
:
usage
.
InputTokens
,
OutputTokens
:
usage
.
OutputTokens
,
CacheReadInputTokens
:
usage
.
CacheReadInputTokens
,
CacheCreationInputTokens
:
usage
.
CacheCreationInputTokens
,
},
},
nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
ClaudeUsage
,
*
int
)
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenRecorded
bool
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
buf
:=
make
([]
byte
,
0
,
64
*
1024
)
scanner
.
Buffer
(
buf
,
1024
*
1024
)
for
scanner
.
Scan
()
{
line
:=
scanner
.
Bytes
()
// 记录首 token 时间
if
!
firstTokenRecorded
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenRecorded
=
true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if
bytes
.
HasPrefix
(
line
,
[]
byte
(
"data: "
))
{
dataStr
:=
bytes
.
TrimPrefix
(
line
,
[]
byte
(
"data: "
))
var
event
map
[
string
]
any
if
json
.
Unmarshal
(
dataStr
,
&
event
)
==
nil
{
if
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
}
}
// 透传行
_
,
_
=
c
.
Writer
.
Write
(
line
)
_
,
_
=
c
.
Writer
.
Write
([]
byte
(
"
\n
"
))
c
.
Writer
.
Flush
()
}
return
usage
,
firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
// ForwardGemini 转发 Gemini 协议请求
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
mappedModel
:=
s
.
getMappedModel
(
account
,
originalModel
)
billingModel
:=
originalModel
if
antigravityUseMappedModelForBilling
()
&&
strings
.
TrimSpace
(
mappedModel
)
!=
""
{
billingModel
=
mappedModel
}
afterSwitch
:=
antigravityHasAccountSwitch
(
ctx
)
maxRetries
:=
antigravityMaxRetriesForModel
(
originalModel
,
afterSwitch
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL
=
account
.
Proxy
.
URL
()
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
)
if
err
!=
nil
{
log
.
Printf
(
"[Antigravity] Failed to filter empty parts: %v"
,
err
)
filteredBody
=
body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
b
ody
)
injectedBody
,
err
:=
injectIdentityPatchToGeminiRequest
(
filteredB
ody
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
maxRetries
:
maxRetries
,
})
if
err
!=
nil
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
...
...
@@ -1493,7 +1914,7 @@ handleSuccess:
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
Model
:
billing
Model
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
return
true
}
func
antigravityHasAccountSwitch
(
ctx
context
.
Context
)
bool
{
if
ctx
==
nil
{
return
false
}
if
v
,
ok
:=
ctx
.
Value
(
ctxkey
.
AccountSwitchCount
)
.
(
int
);
ok
{
return
v
>
0
}
return
false
}
func
antigravityMaxRetries
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesEnv
))
if
raw
==
""
{
return
antigravityDefaultMaxRetries
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityDefaultMaxRetries
}
return
value
}
func
antigravityMaxRetriesAfterSwitch
()
int
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityMaxRetriesAfterSwitchEnv
))
if
raw
==
""
{
return
antigravityMaxRetries
()
}
value
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
value
<=
0
{
return
antigravityMaxRetries
()
}
return
value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func
antigravityMaxRetriesForModel
(
model
string
,
afterSwitch
bool
)
int
{
var
envKey
string
if
strings
.
HasPrefix
(
model
,
"claude-"
)
{
envKey
=
antigravityMaxRetriesClaudeEnv
}
else
if
isImageGenerationModel
(
model
)
{
envKey
=
antigravityMaxRetriesGeminiImageEnv
}
else
if
strings
.
HasPrefix
(
model
,
"gemini-"
)
{
envKey
=
antigravityMaxRetriesGeminiTextEnv
}
if
envKey
!=
""
{
if
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envKey
));
raw
!=
""
{
if
value
,
err
:=
strconv
.
Atoi
(
raw
);
err
==
nil
&&
value
>
0
{
return
value
}
}
}
if
afterSwitch
{
return
antigravityMaxRetriesAfterSwitch
()
}
return
antigravityMaxRetries
()
}
func
antigravityUseMappedModelForBilling
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityBillingModelEnv
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
antigravityFallbackCooldownSeconds
()
(
time
.
Duration
,
bool
)
{
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityFallbackSecondsEnv
))
if
raw
==
""
{
return
0
,
false
}
seconds
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
seconds
<=
0
{
return
0
,
false
}
return
time
.
Duration
(
seconds
)
*
time
.
Second
,
true
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
...
...
@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes
=
s
.
settingService
.
cfg
.
Gateway
.
AntigravityFallbackCooldownMinutes
}
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
if
fallbackDur
,
ok
:=
antigravityFallbackCooldownSeconds
();
ok
{
defaultDur
=
fallbackDur
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
...
...
@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
func
(
s
*
AntigravityGatewayService
)
WriteMappedClaudeError
(
c
*
gin
.
Context
,
account
*
Account
,
upstreamStatus
int
,
upstreamRequestID
string
,
body
[]
byte
)
error
{
return
s
.
writeMappedClaudeError
(
c
,
account
,
upstreamStatus
,
upstreamRequestID
,
body
)
}
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
statusStr
:=
"UNKNOWN"
switch
status
{
...
...
@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return
json
.
Marshal
(
payload
)
}
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
func
filterEmptyPartsFromGeminiRequest
(
body
[]
byte
)
([]
byte
,
error
)
{
var
payload
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
payload
);
err
!=
nil
{
return
nil
,
err
}
contents
,
ok
:=
payload
[
"contents"
]
.
([]
any
)
if
!
ok
||
len
(
contents
)
==
0
{
return
body
,
nil
}
filtered
:=
make
([]
any
,
0
,
len
(
contents
))
modified
:=
false
for
_
,
c
:=
range
contents
{
contentMap
,
ok
:=
c
.
(
map
[
string
]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
parts
,
hasParts
:=
contentMap
[
"parts"
]
if
!
hasParts
{
filtered
=
append
(
filtered
,
c
)
continue
}
partsSlice
,
ok
:=
parts
.
([]
any
)
if
!
ok
{
filtered
=
append
(
filtered
,
c
)
continue
}
// 跳过 parts 为空数组的消息
if
len
(
partsSlice
)
==
0
{
modified
=
true
continue
}
filtered
=
append
(
filtered
,
c
)
}
if
!
modified
{
return
body
,
nil
}
payload
[
"contents"
]
=
filtered
return
json
.
Marshal
(
payload
)
}
backend/internal/service/antigravity_gateway_service_test.go
View file @
de7ff902
package
service
import
(
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
...
...
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
}
func
TestIsPromptTooLongError
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)))
require
.
True
(
t
,
isPromptTooLongError
([]
byte
(
`{"message":"Prompt is too long"}`
)))
require
.
False
(
t
,
isPromptTooLongError
([]
byte
(
`{"error":{"message":"other"}}`
)))
}
type
httpUpstreamStub
struct
{
resp
*
http
.
Response
err
error
}
func
(
s
*
httpUpstreamStub
)
Do
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
(
s
*
httpUpstreamStub
)
DoWithTLS
(
_
*
http
.
Request
,
_
string
,
_
int64
,
_
int
,
_
bool
)
(
*
http
.
Response
,
error
)
{
return
s
.
resp
,
s
.
err
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
body
,
err
:=
json
.
Marshal
(
map
[
string
]
any
{
"model"
:
"claude-opus-4-5"
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"hi"
},
},
"max_tokens"
:
1
,
"stream"
:
false
,
})
require
.
NoError
(
t
,
err
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
=
req
respBody
:=
[]
byte
(
`{"error":{"message":"Prompt is too long"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusBadRequest
,
Header
:
http
.
Header
{
"X-Request-Id"
:
[]
string
{
"req-1"
}},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
svc
:=
&
AntigravityGatewayService
{
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"acc-1"
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
)
require
.
Nil
(
t
,
result
)
var
promptErr
*
PromptTooLongError
require
.
ErrorAs
(
t
,
err
,
&
promptErr
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
promptErr
.
StatusCode
)
require
.
Equal
(
t
,
"req-1"
,
promptErr
.
RequestID
)
require
.
NotEmpty
(
t
,
promptErr
.
Body
)
raw
,
ok
:=
c
.
Get
(
OpsUpstreamErrorsKey
)
require
.
True
(
t
,
ok
)
events
,
ok
:=
raw
.
([]
*
OpsUpstreamErrorEvent
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
events
,
1
)
require
.
Equal
(
t
,
"prompt_too_long"
,
events
[
0
]
.
Kind
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitch
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"4"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
"7"
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
false
)
require
.
Equal
(
t
,
4
,
got
)
got
=
antigravityMaxRetriesForModel
(
"claude-sonnet-4-5"
,
true
)
require
.
Equal
(
t
,
7
,
got
)
}
func
TestAntigravityMaxRetriesForModel_AfterSwitchFallback
(
t
*
testing
.
T
)
{
t
.
Setenv
(
antigravityMaxRetriesEnv
,
"5"
)
t
.
Setenv
(
antigravityMaxRetriesAfterSwitchEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesClaudeEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiTextEnv
,
""
)
t
.
Setenv
(
antigravityMaxRetriesGeminiImageEnv
,
""
)
got
:=
antigravityMaxRetriesForModel
(
"gemini-2.5-flash"
,
true
)
require
.
Equal
(
t
,
5
,
got
)
}
backend/internal/service/antigravity_quota_scope.go
View file @
de7ff902
package
service
import
(
"slices"
"strings"
"time"
)
...
...
@@ -16,6 +17,21 @@ const (
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func
IsScopeSupported
(
supportedScopes
[]
string
,
scope
AntigravityQuotaScope
)
bool
{
if
len
(
supportedScopes
)
==
0
{
// 未配置时默认全部支持
return
true
}
supported
:=
slices
.
Contains
(
supportedScopes
,
string
(
scope
))
return
supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func
ResolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
return
resolveAntigravityQuotaScope
(
requestedModel
)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
...
...
backend/internal/service/api_key.go
View file @
de7ff902
...
...
@@ -2,6 +2,14 @@ package service
import
"time"
// API Key status constants
const
(
StatusAPIKeyActive
=
"active"
StatusAPIKeyDisabled
=
"disabled"
StatusAPIKeyQuotaExhausted
=
"quota_exhausted"
StatusAPIKeyExpired
=
"expired"
)
type
APIKey
struct
{
ID
int64
UserID
int64
...
...
@@ -15,8 +23,53 @@ type APIKey struct {
UpdatedAt
time
.
Time
User
*
User
Group
*
Group
// Quota fields
Quota
float64
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
// Used quota amount
ExpiresAt
*
time
.
Time
// Expiration time (nil = never expires)
}
func
(
k
*
APIKey
)
IsActive
()
bool
{
return
k
.
Status
==
StatusActive
}
// IsExpired checks if the API key has expired
func
(
k
*
APIKey
)
IsExpired
()
bool
{
if
k
.
ExpiresAt
==
nil
{
return
false
}
return
time
.
Now
()
.
After
(
*
k
.
ExpiresAt
)
}
// IsQuotaExhausted checks if the API key quota is exhausted
func
(
k
*
APIKey
)
IsQuotaExhausted
()
bool
{
if
k
.
Quota
<=
0
{
return
false
// unlimited
}
return
k
.
QuotaUsed
>=
k
.
Quota
}
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
func
(
k
*
APIKey
)
GetQuotaRemaining
()
float64
{
if
k
.
Quota
<=
0
{
return
-
1
// unlimited
}
remaining
:=
k
.
Quota
-
k
.
QuotaUsed
if
remaining
<
0
{
return
0
}
return
remaining
}
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
func
(
k
*
APIKey
)
GetDaysUntilExpiry
()
int
{
if
k
.
ExpiresAt
==
nil
{
return
-
1
// never expires
}
duration
:=
time
.
Until
(
*
k
.
ExpiresAt
)
if
duration
<
0
{
return
0
}
return
int
(
duration
.
Hours
()
/
24
)
}
backend/internal/service/api_key_auth_cache.go
View file @
de7ff902
package
service
import
"time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type
APIKeyAuthSnapshot
struct
{
APIKeyID
int64
`json:"api_key_id"`
...
...
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
IPBlacklist
[]
string
`json:"ip_blacklist,omitempty"`
User
APIKeyAuthUserSnapshot
`json:"user"`
Group
*
APIKeyAuthGroupSnapshot
`json:"group,omitempty"`
// Quota fields for API Key independent quota feature
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount
// Expiration field for API Key expiration feature
ExpiresAt
*
time
.
Time
`json:"expires_at,omitempty"`
// Expiration time (nil = never expires)
}
// APIKeyAuthUserSnapshot 用户快照
...
...
@@ -23,29 +32,34 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照
type
APIKeyAuthGroupSnapshot
struct
{
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360,omitempty"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360,omitempty"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
MCPXMLInject
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
de7ff902
...
...
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Status
:
apiKey
.
Status
,
IPWhitelist
:
apiKey
.
IPWhitelist
,
IPBlacklist
:
apiKey
.
IPBlacklist
,
Quota
:
apiKey
.
Quota
,
QuotaUsed
:
apiKey
.
QuotaUsed
,
ExpiresAt
:
apiKey
.
ExpiresAt
,
User
:
APIKeyAuthUserSnapshot
{
ID
:
apiKey
.
User
.
ID
,
Status
:
apiKey
.
User
.
Status
,
...
...
@@ -223,26 +226,29 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
}
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
apiKey
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
apiKey
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
apiKey
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
apiKey
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
apiKey
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
apiKey
.
Group
.
SupportedModelScopes
,
}
}
return
snapshot
...
...
@@ -260,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Status
:
snapshot
.
Status
,
IPWhitelist
:
snapshot
.
IPWhitelist
,
IPBlacklist
:
snapshot
.
IPBlacklist
,
Quota
:
snapshot
.
Quota
,
QuotaUsed
:
snapshot
.
QuotaUsed
,
ExpiresAt
:
snapshot
.
ExpiresAt
,
User
:
&
User
{
ID
:
snapshot
.
User
.
ID
,
Status
:
snapshot
.
User
.
Status
,
...
...
@@ -270,27 +279,30 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
}
if
snapshot
.
Group
!=
nil
{
apiKey
.
Group
=
&
Group
{
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
snapshot
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
snapshot
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
snapshot
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
snapshot
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
SoraImagePrice360
:
snapshot
.
Group
.
SoraImagePrice360
,
SoraImagePrice540
:
snapshot
.
Group
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
snapshot
.
Group
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
snapshot
.
Group
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
snapshot
.
Group
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
MCPXMLInject
:
snapshot
.
Group
.
MCPXMLInject
,
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
}
}
return
apiKey
...
...
backend/internal/service/api_key_service.go
View file @
de7ff902
...
...
@@ -24,6 +24,10 @@ var (
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
ErrAPIKeyExpired
=
infraerrors
.
Forbidden
(
"API_KEY_EXPIRED"
,
"api key 已过期"
)
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted
=
infraerrors
.
TooManyRequests
(
"API_KEY_QUOTA_EXHAUSTED"
,
"api key 额度已用完"
)
)
const
(
...
...
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
// Quota methods
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
}
// APIKeyCache defines cache operations for API key service
...
...
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
// Quota fields
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
ExpiresInDays
*
int
`json:"expires_in_days"`
// Days until expiry (nil = never expires)
}
// UpdateAPIKeyRequest 更新API Key请求
...
...
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
Status
*
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
// Quota fields
Quota
*
float64
`json:"quota"`
// Quota limit in USD (nil = no change, 0 = unlimited)
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = no change)
ClearExpiration
bool
`json:"-"`
// Clear expiration (internal use)
ResetQuota
*
bool
`json:"reset_quota"`
// Reset quota_used to 0
}
// APIKeyService API Key服务
...
...
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
Status
:
StatusActive
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
QuotaUsed
:
0
,
}
// Set expiration time if specified
if
req
.
ExpiresInDays
!=
nil
&&
*
req
.
ExpiresInDays
>
0
{
expiresAt
:=
time
.
Now
()
.
AddDate
(
0
,
0
,
*
req
.
ExpiresInDays
)
apiKey
.
ExpiresAt
=
&
expiresAt
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
...
...
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// Update quota fields
if
req
.
Quota
!=
nil
{
apiKey
.
Quota
=
*
req
.
Quota
// If quota is increased and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
&&
*
req
.
Quota
>
apiKey
.
QuotaUsed
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ResetQuota
!=
nil
&&
*
req
.
ResetQuota
{
apiKey
.
QuotaUsed
=
0
// If resetting quota and status was quota_exhausted, reactivate
if
apiKey
.
Status
==
StatusAPIKeyQuotaExhausted
{
apiKey
.
Status
=
StatusActive
}
}
if
req
.
ClearExpiration
{
apiKey
.
ExpiresAt
=
nil
// If clearing expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
{
apiKey
.
Status
=
StatusActive
}
}
else
if
req
.
ExpiresAt
!=
nil
{
apiKey
.
ExpiresAt
=
req
.
ExpiresAt
// If extending expiry and status was expired, reactivate
if
apiKey
.
Status
==
StatusAPIKeyExpired
&&
time
.
Now
()
.
Before
(
*
req
.
ExpiresAt
)
{
apiKey
.
Status
=
StatusActive
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
...
...
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
}
return
keys
,
nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func
(
s
*
APIKeyService
)
CheckAPIKeyQuotaAndExpiry
(
apiKey
*
APIKey
)
error
{
// Check expiration
if
apiKey
.
IsExpired
()
{
return
ErrAPIKeyExpired
}
// Check quota
if
apiKey
.
IsQuotaExhausted
()
{
return
ErrAPIKeyQuotaExhausted
}
return
nil
}
// UpdateQuotaUsed updates the quota_used field after a request
// Also checks if quota is exhausted and updates status accordingly
func
(
s
*
APIKeyService
)
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
if
cost
<=
0
{
return
nil
}
// Use repository to atomically increment quota_used
newQuotaUsed
,
err
:=
s
.
apiKeyRepo
.
IncrementQuotaUsed
(
ctx
,
apiKeyID
,
cost
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"increment quota used: %w"
,
err
)
}
// Check if quota is now exhausted and update status if needed
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
apiKeyID
)
if
err
!=
nil
{
return
nil
// Don't fail the request, just log
}
// If quota is set and now exhausted, update status
if
apiKey
.
Quota
>
0
&&
newQuotaUsed
>=
apiKey
.
Quota
{
apiKey
.
Status
=
StatusAPIKeyQuotaExhausted
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
// Don't fail the request
}
// Invalidate cache so next request sees the new status
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
}
return
nil
}
backend/internal/service/api_key_service_cache_test.go
View file @
de7ff902
...
...
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
return
s
.
listKeysByGroupID
(
ctx
,
groupID
)
}
func
(
s
*
authRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
type
authCacheStub
struct
{
getAuthCache
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
setAuthKeys
[]
string
...
...
backend/internal/service/api_key_service_delete_test.go
View file @
de7ff902
...
...
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
panic
(
"unexpected ListKeysByGroupID call"
)
}
func
(
s
*
apiKeyRepoStub
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
panic
(
"unexpected IncrementQuotaUsed call"
)
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
...
...
backend/internal/service/auth_service.go
View file @
de7ff902
...
...
@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
...
...
backend/internal/service/domain_constants.go
View file @
de7ff902
...
...
@@ -32,6 +32,7 @@ const (
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
)
// Redeem type constants
...
...
backend/internal/service/gateway_service.go
View file @
de7ff902
...
...
@@ -257,6 +257,9 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var
ErrClaudeCodeOnly
=
errors
.
New
(
"this group only allows Claude Code clients"
)
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var
ErrModelScopeNotSupported
=
errors
.
New
(
"model scope not supported by this group"
)
// allowedHeaders 白名单headers(参考CRS项目)
var
allowedHeaders
=
map
[
string
]
bool
{
"accept"
:
true
,
...
...
@@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string {
}
// replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
func
(
s
*
GatewayService
)
replaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
var
req
map
[
string
]
any
var
req
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
}
req
[
"model"
]
=
newModel
// 只序列化 model 字段
modelBytes
,
err
:=
json
.
Marshal
(
newModel
)
if
err
!=
nil
{
return
body
}
req
[
"model"
]
=
modelBytes
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
...
...
@@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
len
(
body
)
==
0
{
return
body
,
modelID
,
nil
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var
reqRaw
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqRaw
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
,
modelID
,
nil
}
toolNameMap
:=
make
(
map
[
string
]
string
)
modified
:=
false
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
...
...
@@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
modified
=
true
}
case
[]
any
:
for
_
,
item
:=
range
v
{
...
...
@@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized
:=
sanitizeSystemText
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
modified
=
true
}
}
}
...
...
@@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
normalized
!=
rawModel
{
req
[
"model"
]
=
normalized
modelID
=
normalized
modified
=
true
}
}
...
...
@@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
toolMap
[
"name"
]
=
normalized
modified
=
true
}
}
if
desc
,
ok
:=
toolMap
[
"description"
]
.
(
string
);
ok
{
sanitized
:=
sanitizeToolDescription
(
desc
)
if
sanitized
!=
desc
{
toolMap
[
"description"
]
=
sanitized
modified
=
true
}
}
if
schema
,
ok
:=
toolMap
[
"input_schema"
];
ok
{
normalizeToolInputSchema
(
schema
,
toolNameMap
)
modified
=
true
}
tools
[
idx
]
=
toolMap
}
...
...
@@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalizedTools
[
normalized
]
=
value
}
req
[
"tools"
]
=
normalizedTools
modified
=
true
}
}
else
{
req
[
"tools"
]
=
[]
any
{}
modified
=
true
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified
:=
false
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
);
ok
{
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
...
...
@@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
!
ok
{
continue
}
// 检查此消息是否包含 thinking 块
hasThinking
:=
false
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
continue
}
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
if
blockType
==
"thinking"
||
blockType
==
"redacted_thinking"
{
hasThinking
=
true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if
hasThinking
{
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
...
...
@@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized
:=
normalizeToolNameForClaude
(
name
,
toolNameMap
)
if
normalized
!=
""
&&
normalized
!=
name
{
blockMap
[
"name"
]
=
normalized
messagesModified
=
true
}
}
}
...
...
@@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
opts
.
stripSystemCacheControl
{
if
system
,
ok
:=
req
[
"system"
];
ok
{
_
=
stripCacheControlFromSystemBlocks
(
system
)
modified
=
true
}
}
...
...
@@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
if
existing
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
!
ok
||
existing
==
""
{
metadata
[
"user_id"
]
=
opts
.
metadataUserID
modified
=
true
}
}
delete
(
req
,
"temperature"
)
delete
(
req
,
"tool_choice"
)
if
_
,
hasTemp
:=
req
[
"temperature"
];
hasTemp
{
delete
(
req
,
"temperature"
)
modified
=
true
}
if
_
,
hasChoice
:=
req
[
"tool_choice"
];
hasChoice
{
delete
(
req
,
"tool_choice"
)
modified
=
true
}
if
!
modified
&&
!
messagesModified
{
return
body
,
modelID
,
toolNameMap
}
// 如果 messages 没有被修改,保留原始 messages 字节
if
!
messagesModified
{
// 序列化非 messages 字段
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
}
// 替换回原始的 messages
var
newReq
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
newReq
);
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
if
origMessages
,
ok
:=
reqRaw
[
"messages"
];
ok
{
newReq
[
"messages"
]
=
origMessages
}
finalBody
,
err
:=
json
.
Marshal
(
newReq
)
if
err
!=
nil
{
return
newBody
,
modelID
,
toolNameMap
}
return
finalBody
,
modelID
,
toolNameMap
}
// messages 被修改了,需要完整序列化
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
,
modelID
,
toolNameMap
...
...
@@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return
group
,
nil
}
func
(
s
*
GatewayService
)
ResolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
return
s
.
resolveGroupByID
(
ctx
,
groupID
)
}
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
return
nil
...
...
@@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
if
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
&&
forcePlatform
!=
""
{
return
nil
,
groupID
,
nil
}
...
...
@@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if
platform
==
PlatformAntigravity
&&
groupID
!=
nil
&&
requestedModel
!=
""
{
if
err
:=
s
.
checkAntigravityModelScope
(
ctx
,
*
groupID
,
requestedModel
);
err
!=
nil
{
return
nil
,
err
}
}
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
...
...
@@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeAPIKey
{
return
true
...
...
@@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个)
body
=
enforceCacheControlLimit
(
body
)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
:
=
account
.
GetMappedModel
(
reqModel
)
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s)"
,
originalModel
,
mappedModel
,
account
.
Name
)
mappingSource
=
"account"
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s, source=%s)"
,
originalModel
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
// 获取凭证
token
,
tokenType
,
err
:=
s
.
GetAccessToken
(
ctx
,
account
)
...
...
@@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return
true
}
// 检测 thinking block 被修改的错误
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if
strings
.
Contains
(
msg
,
"cannot be modified"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
log
.
Printf
(
"[SignatureCheck] Detected thinking block modification error"
)
return
true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if
strings
.
Contains
(
msg
,
"non-empty content"
)
||
strings
.
Contains
(
msg
,
"empty content"
)
{
...
...
@@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
APIKeyService
APIKeyQuotaUpdater
// 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
type
APIKeyQuotaUpdater
interface
{
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
...
...
@@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
}
// 更新 API Key 配额(如果设置了配额限制)
if
shouldBill
&&
cost
.
ActualCost
>
0
&&
apiKey
.
Quota
>
0
&&
input
.
APIKeyService
!=
nil
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Update API key quota failed: %v"
,
err
)
}
}
// Schedule batch update for account last_used_at
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
...
...
@@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct {
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
APIKeyService
*
APIKeyService
// API Key 配额服务(可选)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...
...
@@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
// 异步更新余额缓存
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
// API Key 独立配额扣费
if
input
.
APIKeyService
!=
nil
&&
apiKey
.
Quota
>
0
{
if
err
:=
input
.
APIKeyService
.
UpdateQuotaUsed
(
ctx
,
apiKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Add API key quota used failed: %v"
,
err
)
}
}
}
}
...
...
@@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return
nil
}
// 应用模型映射(仅对 apikey 类型账号)
if
account
.
Type
==
AccountTypeAPIKey
{
if
reqModel
!=
""
{
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
// 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
if
reqModel
!=
""
{
mappedModel
:=
reqModel
mappingSource
:=
""
if
account
.
Type
==
AccountTypeAPIKey
{
mappedModel
=
account
.
GetMappedModel
(
reqModel
)
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
)
mappingSource
=
"account"
}
}
if
mappingSource
==
""
&&
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
normalized
:=
claude
.
NormalizeModelID
(
reqModel
)
if
normalized
!=
reqModel
{
mappedModel
=
normalized
mappingSource
=
"prefix"
}
}
if
mappedModel
!=
reqModel
{
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
reqModel
=
mappedModel
log
.
Printf
(
"CountTokens model mapping applied: %s -> %s (account: %s, source=%s)"
,
parsed
.
Model
,
mappedModel
,
account
.
Name
,
mappingSource
)
}
}
// 获取凭证
...
...
@@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return
normalized
,
nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func
(
s
*
GatewayService
)
checkAntigravityModelScope
(
ctx
context
.
Context
,
groupID
int64
,
requestedModel
string
)
error
{
scope
,
ok
:=
ResolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
nil
// 无法解析 scope,跳过检查
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
// 查询失败时放行
}
if
group
==
nil
{
return
nil
// 分组不存在时放行
}
if
!
IsScopeSupported
(
group
.
SupportedModelScopes
,
scope
)
{
return
ErrModelScopeNotSupported
}
return
nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
de7ff902
...
...
@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
if
filteredBody
,
err
:=
filterEmptyPartsFromGeminiRequest
(
body
);
err
==
nil
{
body
=
filteredBody
}
switch
action
{
case
"generateContent"
,
"streamGenerateContent"
,
"countTokens"
:
// ok
...
...
backend/internal/service/gemini_native_signature_cleaner.go
View file @
de7ff902
...
...
@@ -2,20 +2,22 @@ package service
import
(
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
移除
thoughtSignature 字段,
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中
替换
thoughtSignature 字段
为 dummy 签名
,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过
移除这些签名,让新账号重新生成有效的签名
。
// 会导致新账号的签名验证失败。通过
替换为 dummy 签名,跳过签名验证
。
//
// CleanGeminiNativeThoughtSignatures re
mov
es thoughtSignature fields
from Gemini native API requests
// to avoid cross-account signature validation errors.
// CleanGeminiNativeThoughtSignatures re
plac
es thoughtSignature fields
with dummy signature
//
in Gemini native API requests
to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By re
moving these
signature
s
, we
allow the new account to generate valid signatures
.
// By re
placing with dummy
signature, we
skip signature validation
.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
...
...
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
body
}
// 递归
清理
thoughtSignature
cleaned
:=
clean
ThoughtSignaturesRecursive
(
data
)
// 递归
替换
thoughtSignature
为 dummy 签名
replaced
:=
replace
ThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
clean
ed
)
result
,
err
:=
json
.
Marshal
(
replac
ed
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
...
...
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return
result
}
//
clean
ThoughtSignaturesRecursive 递归遍历数据结构,
移除
所有 thoughtSignature 字段
func
clean
ThoughtSignaturesRecursive
(
data
any
)
any
{
//
replace
ThoughtSignaturesRecursive 递归遍历数据结构,
将
所有 thoughtSignature 字段
替换为 dummy 签名
func
replace
ThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,
移除
thoughtSignature
// 创建新的 map,
替换
thoughtSignature
为 dummy 签名
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
//
跳过
thoughtSignature 字段
//
替换
thoughtSignature 字段
为 dummy 签名
if
key
==
"thoughtSignature"
{
result
[
key
]
=
antigravity
.
DummyThoughtSignature
continue
}
// 递归处理嵌套结构
result
[
key
]
=
clean
ThoughtSignaturesRecursive
(
value
)
result
[
key
]
=
replace
ThoughtSignaturesRecursive
(
value
)
}
return
result
...
...
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
clean
ThoughtSignaturesRecursive
(
item
)
result
[
i
]
=
replace
ThoughtSignaturesRecursive
(
item
)
}
return
result
...
...
backend/internal/service/group.go
View file @
de7ff902
...
...
@@ -35,6 +35,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly
bool
FallbackGroupID
*
int64
// 无效请求兜底分组(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest
*
int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
...
...
@@ -42,6 +44,13 @@ type Group struct {
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject
bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes
[]
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
...
...
backend/internal/service/identity_service.go
View file @
de7ff902
...
...
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
// RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserID
(
body
[]
byte
,
accountID
int64
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
if
len
(
body
)
==
0
||
accountUUID
==
""
||
cachedClientID
==
""
{
return
body
,
nil
}
//
解析JSON
var
reqMap
map
[
string
]
any
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
reqMap
);
err
!=
nil
{
return
body
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
return
body
,
nil
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
body
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
body
,
nil
...
...
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
newUserID
:=
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
cachedClientID
,
accountUUID
,
newSessionHash
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
err
:=
json
.
Marshal
(
metadata
)
if
err
!=
nil
{
return
body
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
}
...
...
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
// 先执行常规的 RewriteUserID 逻辑
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
...
...
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return
newBody
,
nil
}
//
解析重写后的 body,提取 user_id
var
reqMap
map
[
string
]
any
//
使用 RawMessage 保留其他字段的原始字节
var
reqMap
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
return
newBody
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
// 解析 metadata 字段
metadataRaw
,
ok
:=
reqMap
[
"metadata"
]
if
!
ok
{
return
newBody
,
nil
}
var
metadata
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
metadataRaw
,
&
metadata
);
err
!=
nil
{
return
newBody
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
newBody
,
nil
...
...
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
// 只重新序列化 metadata 字段
newMetadataRaw
,
marshalErr
:=
json
.
Marshal
(
metadata
)
if
marshalErr
!=
nil
{
return
newBody
,
nil
}
reqMap
[
"metadata"
]
=
newMetadataRaw
return
json
.
Marshal
(
reqMap
)
}
...
...
backend/internal/service/openai_codex_transform.go
View file @
de7ff902
...
...
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
LastChecked
int64
`json:"lastChecked"`
}
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
)
codexTransformResult
{
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
codexTransformResult
{
result
:=
codexTransformResult
{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation
:=
NeedsToolContinuation
(
reqBody
)
...
...
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result
.
PromptCacheKey
=
strings
.
TrimSpace
(
v
)
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
result
.
Modified
=
true
}
}
else
if
existingInstructions
==
""
{
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
result
.
Modified
=
true
}
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
if
applyInstructions
(
reqBody
,
isCodexCLI
)
{
result
.
Modified
=
true
}
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
...
...
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
return
getCodexCLIInstructions
()
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
func
applyInstructions
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
bool
{
if
isCodexCLI
{
return
applyCodexCLIInstructions
(
reqBody
)
}
return
applyOpenCodeInstructions
(
reqBody
)
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
func
applyCodexCLIInstructions
(
reqBody
map
[
string
]
any
)
bool
{
if
!
isInstructionsEmpty
(
reqBody
)
{
return
false
// 已有有效 instructions,不修改
}
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
if
instructions
!=
""
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
return
false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
func
applyOpenCodeInstructions
(
reqBody
map
[
string
]
any
)
bool
{
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
existingInstructions
=
strings
.
TrimSpace
(
existingInstructions
)
if
instructions
!=
""
{
if
existingInstructions
!=
instructions
{
reqBody
[
"instructions"
]
=
instructions
return
true
}
}
else
if
existingInstructions
==
""
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
if
codexInstructions
!=
""
{
reqBody
[
"instructions"
]
=
codexInstructions
return
true
}
}
return
false
}
// isInstructionsEmpty 检查 instructions 字段是否为空
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
func
isInstructionsEmpty
(
reqBody
map
[
string
]
any
)
bool
{
val
,
exists
:=
reqBody
[
"instructions"
]
if
!
exists
{
return
true
}
if
val
==
nil
{
return
true
}
str
,
ok
:=
val
.
(
string
)
if
!
ok
{
return
true
}
return
strings
.
TrimSpace
(
str
)
==
""
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func
ReplaceWithCodexInstructions
(
reqBody
map
[
string
]
any
)
bool
{
codexInstructions
:=
strings
.
TrimSpace
(
getCodexCLIInstructions
())
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment