Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c7e18bd5
Unverified
Commit
c7e18bd5
authored
Feb 25, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 25, 2026
Browse files
Merge pull request #627 from touwaeriol/pr/bugfixes-and-enhancements
feat: 反重力(Antigravity)增强、Failover 重构及新模型支持
parents
516f8f28
8365a832
Changes
41
Show whitespace changes
Inline
Side-by-side
backend/internal/config/config.go
View file @
c7e18bd5
...
...
@@ -1158,6 +1158,7 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.force_codex_cli"
,
false
)
viper
.
SetDefault
(
"gateway.openai_passthrough_allow_timeout_headers"
,
false
)
viper
.
SetDefault
(
"gateway.antigravity_fallback_cooldown_minutes"
,
1
)
viper
.
SetDefault
(
"gateway.antigravity_extra_retries"
,
10
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.upstream_response_read_max_bytes"
,
int64
(
8
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.proxy_probe_response_read_max_bytes"
,
int64
(
1024
*
1024
))
...
...
backend/internal/domain/constants.go
View file @
c7e18bd5
...
...
@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-6-thinking"
:
"claude-opus-4-6-thinking"
,
// 官方模型
"claude-opus-4-6"
:
"claude-opus-4-6-thinking"
,
// 简称映射
"claude-opus-4-5-thinking"
:
"claude-opus-4-6-thinking"
,
// 迁移旧模型
"claude-sonnet-4-6"
:
"claude-sonnet-4-6"
,
"claude-sonnet-4-5"
:
"claude-sonnet-4-5"
,
"claude-sonnet-4-5-thinking"
:
"claude-sonnet-4-5-thinking"
,
// Claude 详细版本 ID 映射
...
...
@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-pro"
:
"gemini-2.5-pro"
,
// Gemini 3 白名单
"gemini-3-flash"
:
"gemini-3-flash"
,
"gemini-3-pro-high"
:
"gemini-3
.1
-pro-high"
,
"gemini-3-pro-low"
:
"gemini-3
.1
-pro-low"
,
"gemini-3-pro-high"
:
"gemini-3-pro-high"
,
"gemini-3-pro-low"
:
"gemini-3-pro-low"
,
"gemini-3-pro-image"
:
"gemini-3-pro-image"
,
// Gemini 3.1 透传
"gemini-3.1-pro-high"
:
"gemini-3.1-pro-high"
,
"gemini-3.1-pro-low"
:
"gemini-3.1-pro-low"
,
// Gemini 3 preview 映射
"gemini-3-flash-preview"
:
"gemini-3-flash"
,
"gemini-3-pro-preview"
:
"gemini-3
.1
-pro-high"
,
"gemini-3-pro-preview"
:
"gemini-3-pro-high"
,
"gemini-3-pro-image-preview"
:
"gemini-3-pro-image"
,
// Gemini 3.1 白名单
"gemini-3.1-pro-high"
:
"gemini-3.1-pro-high"
,
"gemini-3.1-pro-low"
:
"gemini-3.1-pro-low"
,
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview"
:
"gemini-3.1-pro-high"
,
// 其他官方模型
"gpt-oss-120b-medium"
:
"gpt-oss-120b-medium"
,
"tab_flash_lite_preview"
:
"tab_flash_lite_preview"
,
...
...
backend/internal/handler/admin/account_handler.go
View file @
c7e18bd5
...
...
@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk
*
bool
`json:"confirm_mixed_channel_risk"`
// 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type
CheckMixedChannelRequest
struct
{
Platform
string
`json:"platform" binding:"required"`
GroupIDs
[]
int64
`json:"group_ids"`
AccountID
*
int64
`json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type
AccountWithConcurrency
struct
{
*
dto
.
Account
...
...
@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func
(
h
*
AccountHandler
)
CheckMixedChannel
(
c
*
gin
.
Context
)
{
var
req
CheckMixedChannelRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
len
(
req
.
GroupIDs
)
==
0
{
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
false
})
return
}
accountID
:=
int64
(
0
)
if
req
.
AccountID
!=
nil
{
accountID
=
*
req
.
AccountID
}
err
:=
h
.
adminService
.
CheckMixedChannelRisk
(
c
.
Request
.
Context
(),
accountID
,
req
.
Platform
,
req
.
GroupIDs
)
if
err
!=
nil
{
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
true
,
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
})
return
}
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
false
})
}
// Create handles creating a new account
// POST /api/v1/admin/accounts
func
(
h
*
AccountHandler
)
Create
(
c
*
gin
.
Context
)
{
...
...
@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 检查是否为混合渠道错误
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
//
返回特殊错误码要求确认
//
创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c
.
JSON
(
409
,
gin
.
H
{
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
"require_confirmation"
:
true
,
})
return
}
...
...
@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
//
返回特殊错误码要求确认
//
更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c
.
JSON
(
409
,
gin
.
H
{
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
"require_confirmation"
:
true
,
})
return
}
...
...
backend/internal/handler/admin/account_handler_mixed_channel_test.go
0 → 100644
View file @
c7e18bd5
package
admin
import
(
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
setupAccountMixedChannelRouter
(
adminSvc
*
stubAdminService
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
accountHandler
:=
NewAccountHandler
(
adminSvc
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
router
.
POST
(
"/api/v1/admin/accounts/check-mixed-channel"
,
accountHandler
.
CheckMixedChannel
)
router
.
POST
(
"/api/v1/admin/accounts"
,
accountHandler
.
Create
)
router
.
PUT
(
"/api/v1/admin/accounts/:id"
,
accountHandler
.
Update
)
return
router
}
func
TestAccountHandlerCheckMixedChannelNoRisk
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"platform"
:
"antigravity"
,
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/check-mixed-channel"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
float64
(
0
),
resp
[
"code"
])
data
,
ok
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
false
,
data
[
"has_risk"
])
require
.
Equal
(
t
,
int64
(
0
),
adminSvc
.
lastMixedCheck
.
accountID
)
require
.
Equal
(
t
,
"antigravity"
,
adminSvc
.
lastMixedCheck
.
platform
)
require
.
Equal
(
t
,
[]
int64
{
27
},
adminSvc
.
lastMixedCheck
.
groupIDs
)
}
func
TestAccountHandlerCheckMixedChannelWithRisk
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
checkMixedErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"platform"
:
"antigravity"
,
"group_ids"
:
[]
int64
{
27
},
"account_id"
:
99
,
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/check-mixed-channel"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
float64
(
0
),
resp
[
"code"
])
data
,
ok
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
true
,
data
[
"has_risk"
])
require
.
Equal
(
t
,
"mixed_channel_warning"
,
data
[
"error"
])
details
,
ok
:=
data
[
"details"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
float64
(
27
),
details
[
"group_id"
])
require
.
Equal
(
t
,
"claude-max"
,
details
[
"group_name"
])
require
.
Equal
(
t
,
"Antigravity"
,
details
[
"current_platform"
])
require
.
Equal
(
t
,
"Anthropic"
,
details
[
"other_platform"
])
require
.
Equal
(
t
,
int64
(
99
),
adminSvc
.
lastMixedCheck
.
accountID
)
}
func
TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
createAccountErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"name"
:
"ag-oauth-1"
,
"platform"
:
"antigravity"
,
"type"
:
"oauth"
,
"credentials"
:
map
[
string
]
any
{
"refresh_token"
:
"rt"
},
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"mixed_channel_warning"
,
resp
[
"error"
])
require
.
Contains
(
t
,
resp
[
"message"
],
"mixed_channel_warning"
)
_
,
hasDetails
:=
resp
[
"details"
]
_
,
hasRequireConfirmation
:=
resp
[
"require_confirmation"
]
require
.
False
(
t
,
hasDetails
)
require
.
False
(
t
,
hasRequireConfirmation
)
}
func
TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
updateAccountErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/accounts/3"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"mixed_channel_warning"
,
resp
[
"error"
])
require
.
Contains
(
t
,
resp
[
"message"
],
"mixed_channel_warning"
)
_
,
hasDetails
:=
resp
[
"details"
]
_
,
hasRequireConfirmation
:=
resp
[
"require_confirmation"
]
require
.
False
(
t
,
hasDetails
)
require
.
False
(
t
,
hasRequireConfirmation
)
}
backend/internal/handler/admin/admin_service_stub_test.go
View file @
c7e18bd5
...
...
@@ -22,6 +22,14 @@ type stubAdminService struct {
updatedProxyIDs
[]
int64
updatedProxies
[]
*
service
.
UpdateProxyInput
testedProxyIDs
[]
int64
createAccountErr
error
updateAccountErr
error
checkMixedErr
error
lastMixedCheck
struct
{
accountID
int64
platform
string
groupIDs
[]
int64
}
mu
sync
.
Mutex
}
...
...
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s
.
mu
.
Lock
()
s
.
createdAccounts
=
append
(
s
.
createdAccounts
,
input
)
s
.
mu
.
Unlock
()
if
s
.
createAccountErr
!=
nil
{
return
nil
,
s
.
createAccountErr
}
account
:=
service
.
Account
{
ID
:
300
,
Name
:
input
.
Name
,
Status
:
service
.
StatusActive
}
return
&
account
,
nil
}
func
(
s
*
stubAdminService
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
service
.
UpdateAccountInput
)
(
*
service
.
Account
,
error
)
{
if
s
.
updateAccountErr
!=
nil
{
return
nil
,
s
.
updateAccountErr
}
account
:=
service
.
Account
{
ID
:
id
,
Name
:
input
.
Name
,
Status
:
service
.
StatusActive
}
return
&
account
,
nil
}
...
...
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
return
&
service
.
BulkUpdateAccountsResult
{
Success
:
1
,
Failed
:
0
,
SuccessIDs
:
[]
int64
{
1
}},
nil
}
func
(
s
*
stubAdminService
)
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
s
.
lastMixedCheck
.
accountID
=
currentAccountID
s
.
lastMixedCheck
.
platform
=
currentAccountPlatform
s
.
lastMixedCheck
.
groupIDs
=
append
([]
int64
(
nil
),
groupIDs
...
)
return
s
.
checkMixedErr
}
func
(
s
*
stubAdminService
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
service
.
Proxy
,
int64
,
error
)
{
search
=
strings
.
TrimSpace
(
strings
.
ToLower
(
search
))
filtered
:=
make
([]
service
.
Proxy
,
0
,
len
(
s
.
proxies
))
...
...
backend/internal/handler/failover_loop.go
0 → 100644
View file @
c7e18bd5
package
handler
import
(
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
// GatewayService 隐式实现此接口。
type
TempUnscheduler
interface
{
TempUnscheduleRetryableError
(
ctx
context
.
Context
,
accountID
int64
,
failoverErr
*
service
.
UpstreamFailoverError
)
}
// FailoverAction 表示 failover 错误处理后的下一步动作
type
FailoverAction
int
const
(
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
FailoverContinue
FailoverAction
=
iota
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
FailoverExhausted
// FailoverCanceled context 已取消(调用方应直接 return)
FailoverCanceled
)
const
(
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries
=
2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay
=
500
*
time
.
Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
singleAccountBackoffDelay
=
2
*
time
.
Second
)
// FailoverState 跨循环迭代共享的 failover 状态
type
FailoverState
struct
{
SwitchCount
int
MaxSwitches
int
FailedAccountIDs
map
[
int64
]
struct
{}
SameAccountRetryCount
map
[
int64
]
int
LastFailoverErr
*
service
.
UpstreamFailoverError
ForceCacheBilling
bool
hasBoundSession
bool
}
// NewFailoverState 创建 failover 状态
func
NewFailoverState
(
maxSwitches
int
,
hasBoundSession
bool
)
*
FailoverState
{
return
&
FailoverState
{
MaxSwitches
:
maxSwitches
,
FailedAccountIDs
:
make
(
map
[
int64
]
struct
{}),
SameAccountRetryCount
:
make
(
map
[
int64
]
int
),
hasBoundSession
:
hasBoundSession
,
}
}
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
func
(
s
*
FailoverState
)
HandleFailoverError
(
ctx
context
.
Context
,
gatewayService
TempUnscheduler
,
accountID
int64
,
platform
string
,
failoverErr
*
service
.
UpstreamFailoverError
,
)
FailoverAction
{
s
.
LastFailoverErr
=
failoverErr
// 缓存计费判断
if
needForceCacheBilling
(
s
.
hasBoundSession
,
failoverErr
)
{
s
.
ForceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
s
.
SameAccountRetryCount
[
accountID
]
<
maxSameAccountRetries
{
s
.
SameAccountRetryCount
[
accountID
]
++
log
.
Printf
(
"Account %d: retryable error %d, same-account retry %d/%d"
,
accountID
,
failoverErr
.
StatusCode
,
s
.
SameAccountRetryCount
[
accountID
],
maxSameAccountRetries
)
if
!
sleepWithContext
(
ctx
,
sameAccountRetryDelay
)
{
return
FailoverCanceled
}
return
FailoverContinue
}
// 同账号重试用尽,执行临时封禁
if
failoverErr
.
RetryableOnSameAccount
{
gatewayService
.
TempUnscheduleRetryableError
(
ctx
,
accountID
,
failoverErr
)
}
// 加入失败列表
s
.
FailedAccountIDs
[
accountID
]
=
struct
{}{}
// 检查是否耗尽
if
s
.
SwitchCount
>=
s
.
MaxSwitches
{
return
FailoverExhausted
}
// 递增切换计数
s
.
SwitchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
accountID
,
failoverErr
.
StatusCode
,
s
.
SwitchCount
,
s
.
MaxSwitches
)
// Antigravity 平台换号线性递增延时
if
platform
==
service
.
PlatformAntigravity
{
delay
:=
time
.
Duration
(
s
.
SwitchCount
-
1
)
*
time
.
Second
if
!
sleepWithContext
(
ctx
,
delay
)
{
return
FailoverCanceled
}
}
return
FailoverContinue
}
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
// 清除排除列表、等待退避后重新选号。
//
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
// 返回 FailoverExhausted 时,调用方应返回错误响应。
// 返回 FailoverCanceled 时,调用方应直接 return。
func
(
s
*
FailoverState
)
HandleSelectionExhausted
(
ctx
context
.
Context
)
FailoverAction
{
if
s
.
LastFailoverErr
!=
nil
&&
s
.
LastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
s
.
SwitchCount
<=
s
.
MaxSwitches
{
log
.
Printf
(
"Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)"
,
singleAccountBackoffDelay
,
s
.
SwitchCount
)
if
!
sleepWithContext
(
ctx
,
singleAccountBackoffDelay
)
{
return
FailoverCanceled
}
log
.
Printf
(
"Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d"
,
s
.
SwitchCount
,
s
.
MaxSwitches
)
s
.
FailedAccountIDs
=
make
(
map
[
int64
]
struct
{})
return
FailoverContinue
}
return
FailoverExhausted
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
func
needForceCacheBilling
(
hasBoundSession
bool
,
failoverErr
*
service
.
UpstreamFailoverError
)
bool
{
return
hasBoundSession
||
(
failoverErr
!=
nil
&&
failoverErr
.
ForceCacheBilling
)
}
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
func
sleepWithContext
(
ctx
context
.
Context
,
d
time
.
Duration
)
bool
{
if
d
<=
0
{
return
true
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
d
)
:
return
true
}
}
backend/internal/handler/failover_loop_test.go
0 → 100644
View file @
c7e18bd5
package
handler
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock
// ---------------------------------------------------------------------------
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
type
mockTempUnscheduler
struct
{
calls
[]
tempUnscheduleCall
}
type
tempUnscheduleCall
struct
{
accountID
int64
failoverErr
*
service
.
UpstreamFailoverError
}
func
(
m
*
mockTempUnscheduler
)
TempUnscheduleRetryableError
(
_
context
.
Context
,
accountID
int64
,
failoverErr
*
service
.
UpstreamFailoverError
)
{
m
.
calls
=
append
(
m
.
calls
,
tempUnscheduleCall
{
accountID
:
accountID
,
failoverErr
:
failoverErr
})
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func
newTestFailoverErr
(
statusCode
int
,
retryable
,
forceBilling
bool
)
*
service
.
UpstreamFailoverError
{
return
&
service
.
UpstreamFailoverError
{
StatusCode
:
statusCode
,
RetryableOnSameAccount
:
retryable
,
ForceCacheBilling
:
forceBilling
,
}
}
// ---------------------------------------------------------------------------
// NewFailoverState 测试
// ---------------------------------------------------------------------------
func
TestNewFailoverState
(
t
*
testing
.
T
)
{
t
.
Run
(
"初始化字段正确"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
5
,
true
)
require
.
Equal
(
t
,
5
,
fs
.
MaxSwitches
)
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
)
require
.
NotNil
(
t
,
fs
.
FailedAccountIDs
)
require
.
Empty
(
t
,
fs
.
FailedAccountIDs
)
require
.
NotNil
(
t
,
fs
.
SameAccountRetryCount
)
require
.
Empty
(
t
,
fs
.
SameAccountRetryCount
)
require
.
Nil
(
t
,
fs
.
LastFailoverErr
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
require
.
True
(
t
,
fs
.
hasBoundSession
)
})
t
.
Run
(
"无绑定会话"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
require
.
Equal
(
t
,
3
,
fs
.
MaxSwitches
)
require
.
False
(
t
,
fs
.
hasBoundSession
)
})
t
.
Run
(
"零最大切换次数"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
0
,
false
)
require
.
Equal
(
t
,
0
,
fs
.
MaxSwitches
)
})
}
// ---------------------------------------------------------------------------
// sleepWithContext 测试
// ---------------------------------------------------------------------------
func
TestSleepWithContext
(
t
*
testing
.
T
)
{
t
.
Run
(
"零时长立即返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
0
)
require
.
True
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"负时长立即返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
-
1
*
time
.
Second
)
require
.
True
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"正常等待后返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
50
*
time
.
Millisecond
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
40
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
)
})
t
.
Run
(
"已取消context立即返回false"
,
func
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
ctx
,
5
*
time
.
Second
)
require
.
False
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"等待期间context取消返回false"
,
func
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
go
func
()
{
time
.
Sleep
(
30
*
time
.
Millisecond
)
cancel
()
}()
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
ctx
,
5
*
time
.
Second
)
elapsed
:=
time
.
Since
(
start
)
require
.
False
(
t
,
ok
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 基本切换流程
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_BasicSwitch
(
t
*
testing
.
T
)
{
t
.
Run
(
"非重试错误_非Antigravity_直接切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
require
.
Equal
(
t
,
err
,
fs
.
LastFailoverErr
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
require
.
Empty
(
t
,
mock
.
calls
,
"不应调用 TempUnschedule"
)
})
t
.
Run
(
"非重试错误_Antigravity_第一次切换无延迟"
,
func
(
t
*
testing
.
T
)
{
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"第一次切换延迟应为 0"
)
})
t
.
Run
(
"非重试错误_Antigravity_第二次切换有1秒延迟"
,
func
(
t
*
testing
.
T
)
{
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
// 模拟已切换一次
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
800
*
time
.
Millisecond
,
"第二次切换延迟应约 1s"
)
require
.
Less
(
t
,
elapsed
,
3
*
time
.
Second
)
})
t
.
Run
(
"连续切换直到耗尽"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
2
,
false
)
// 第一次切换:0→1
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
// 第二次切换:1→2
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
err3
:=
newTestFailoverErr
(
503
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
err3
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
,
"耗尽时不应继续递增"
)
// 验证失败账号列表
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
3
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
200
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
300
))
// LastFailoverErr 应为最后一次的错误
require
.
Equal
(
t
,
err3
,
fs
.
LastFailoverErr
)
})
t
.
Run
(
"MaxSwitches为0时首次即耗尽"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
0
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_CacheBilling
(
t
*
testing
.
T
)
{
t
.
Run
(
"hasBoundSession为true时设置ForceCacheBilling"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
true
)
// hasBoundSession=true
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"failoverErr.ForceCacheBilling为true时设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
true
)
// ForceCacheBilling=true
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"两者均为false时不设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"一旦设置不会被后续错误重置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
// 第一次:ForceCacheBilling=true → 设置
err1
:=
newTestFailoverErr
(
500
,
false
,
true
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
// 第二次:ForceCacheBilling=false → 仍然保持 true
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"ForceCacheBilling 一旦设置不应被重置"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_SameAccountRetry
(
t
*
testing
.
T
)
{
t
.
Run
(
"第一次重试返回FailoverContinue"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
,
"同账号重试不应增加切换计数"
)
require
.
NotContains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
),
"同账号重试不应加入失败列表"
)
require
.
Empty
(
t
,
mock
.
calls
,
"同账号重试期间不应调用 TempUnschedule"
)
// 验证等待了 sameAccountRetryDelay (500ms)
require
.
GreaterOrEqual
(
t
,
elapsed
,
400
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
2
*
time
.
Second
)
})
t
.
Run
(
"第二次重试仍返回FailoverContinue"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 第一次
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
// 第二次
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SameAccountRetryCount
[
100
])
require
.
Empty
(
t
,
mock
.
calls
,
"两次重试期间均不应调用 TempUnschedule"
)
})
t
.
Run
(
"第三次重试耗尽_触发TempUnschedule并切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 第一次、第二次重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
2
,
fs
.
SameAccountRetryCount
[
100
])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
// 验证 TempUnschedule 被调用
require
.
Len
(
t
,
mock
.
calls
,
1
)
require
.
Equal
(
t
,
int64
(
100
),
mock
.
calls
[
0
]
.
accountID
)
require
.
Equal
(
t
,
err
,
mock
.
calls
[
0
]
.
failoverErr
)
})
t
.
Run
(
"不同账号独立跟踪重试次数"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 账号 100 第一次重试
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
// 账号 200 第一次重试(独立计数)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
200
])
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
],
"账号 100 的计数不应受影响"
)
})
t
.
Run
(
"重试耗尽后再次遇到同账号_直接切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 耗尽账号 100 的重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
// 第三次: 重试耗尽 → 切换
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Len
(
t
,
mock
.
calls
,
2
,
"第二次耗尽也应调用 TempUnschedule"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — TempUnschedule 调用验证
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_TempUnschedule
(
t
*
testing
.
T
)
{
t
.
Run
(
"非重试错误不调用TempUnschedule"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
// RetryableOnSameAccount=false
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Empty
(
t
,
mock
.
calls
)
})
t
.
Run
(
"重试错误耗尽后调用TempUnschedule_传入正确参数"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
502
,
true
,
false
)
// 耗尽重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
require
.
Len
(
t
,
mock
.
calls
,
1
)
require
.
Equal
(
t
,
int64
(
42
),
mock
.
calls
[
0
]
.
accountID
)
require
.
Equal
(
t
,
502
,
mock
.
calls
[
0
]
.
failoverErr
.
StatusCode
)
require
.
True
(
t
,
mock
.
calls
[
0
]
.
failoverErr
.
RetryableOnSameAccount
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — Context 取消
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_ContextCanceled
(
t
*
testing
.
T
)
{
t
.
Run
(
"同账号重试sleep期间context取消"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
ctx
,
mock
,
100
,
"openai"
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回"
)
// 重试计数仍应递增
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
})
t
.
Run
(
"Antigravity延迟期间context取消"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
// 下一次 switchCount=2 → delay = 1s
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
ctx
,
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回而非等待 1s"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — FailedAccountIDs 跟踪
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_FailedAccountIDs
(
t
*
testing
.
T
)
{
t
.
Run
(
"切换时添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
newTestFailoverErr
(
502
,
false
,
false
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
200
))
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
2
)
})
t
.
Run
(
"耗尽时也添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
0
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
t
.
Run
(
"同账号重试期间不添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
400
,
true
,
false
))
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
NotContains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
t
.
Run
(
"同一账号多次切换不重复添加"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
1
,
"map 天然去重"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — LastFailoverErr 更新
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_LastFailoverErr
(
t
*
testing
.
T
)
{
t
.
Run
(
"每次调用都更新LastFailoverErr"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
Equal
(
t
,
err1
,
fs
.
LastFailoverErr
)
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
Equal
(
t
,
err2
,
fs
.
LastFailoverErr
)
})
t
.
Run
(
"同账号重试时也更新LastFailoverErr"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
err
,
fs
.
LastFailoverErr
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 综合集成场景
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_IntegrationScenario
(
t
*
testing
.
T
)
{
t
.
Run
(
"模拟完整failover流程_多账号混合重试与切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
true
)
// hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
retryErr
:=
newTestFailoverErr
(
400
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"hasBoundSession=true 应设置 ForceCacheBilling"
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Len
(
t
,
mock
.
calls
,
1
)
// 3. 账号 200 遇到不可重试错误 → 直接切换
switchErr
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
// 4. 账号 300 遇到不可重试错误 → 再切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
3
,
fs
.
SwitchCount
)
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
400
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
// 最终状态验证
require
.
Equal
(
t
,
3
,
fs
.
SwitchCount
,
"耗尽时不再递增"
)
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
4
,
"4个不同账号都在失败列表中"
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
require
.
Len
(
t
,
mock
.
calls
,
1
,
"只有账号 100 触发了 TempUnschedule"
)
})
t
.
Run
(
"模拟Antigravity平台完整流程"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
2
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
// 第一次切换:delay = 0s
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"第一次切换延迟为 0"
)
// 第二次切换:delay = 1s
start
=
time
.
Now
()
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
service
.
PlatformAntigravity
,
err
)
elapsed
=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
800
*
time
.
Millisecond
,
"第二次切换延迟约 1s"
)
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
start
=
time
.
Now
()
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
service
.
PlatformAntigravity
,
err
)
elapsed
=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"耗尽时不应有延迟"
)
})
t
.
Run
(
"ForceCacheBilling通过错误标志设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
// hasBoundSession=false
// 第一次:ForceCacheBilling=false
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
err2
:=
newTestFailoverErr
(
500
,
false
,
true
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"错误标志应触发 ForceCacheBilling"
)
// 第三次:ForceCacheBilling=false,但状态仍保持 true
err3
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
err3
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"不应重置"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 边界条件
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_EdgeCases
(
t
*
testing
.
T
)
{
t
.
Run
(
"StatusCode为0的错误也能正常处理"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
0
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
})
t
.
Run
(
"AccountID为0也能正常跟踪"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
0
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
0
])
})
t
.
Run
(
"负AccountID也能正常跟踪"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
-
1
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
-
1
])
})
t
.
Run
(
"空平台名称不触发Antigravity延迟"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
""
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"空平台不应触发 Antigravity 延迟"
)
})
}
// ---------------------------------------------------------------------------
// HandleSelectionExhausted 测试
// ---------------------------------------------------------------------------
func
TestHandleSelectionExhausted
(
t
*
testing
.
T
)
{
t
.
Run
(
"无LastFailoverErr时返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
// LastFailoverErr 为 nil
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
})
t
.
Run
(
"非503错误返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
})
t
.
Run
(
"503且未耗尽_等待后返回Continue并清除失败列表"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
FailedAccountIDs
[
100
]
=
struct
{}{}
fs
.
SwitchCount
=
1
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Empty
(
t
,
fs
.
FailedAccountIDs
,
"应清除失败账号列表"
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
,
"应等待约 2s"
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
)
})
t
.
Run
(
"503但SwitchCount已超过MaxSwitches_返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
2
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
SwitchCount
=
3
// > MaxSwitches(2)
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"不应等待"
)
})
t
.
Run
(
"503但context已取消_返回Canceled"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
ctx
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回"
)
})
t
.
Run
(
"503且SwitchCount等于MaxSwitches_仍可重试"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
2
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
SwitchCount
=
2
// == MaxSwitches,条件是 <=,仍可重试
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverContinue
,
action
)
})
}
backend/internal/handler/gateway_handler.go
View file @
c7e18bd5
...
...
@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
...
...
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
if
platform
==
service
.
PlatformGemini
{
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
// 同账号重试计数
var
lastFailoverErr
*
service
.
UpstreamFailoverError
var
forceCacheBilling
bool
// 粘性会话切换时的缓存计费标记
fs
:=
NewFailoverState
(
h
.
maxAccountSwitchesGemini
,
hasBoundSession
)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
...
...
@@ -272,36 +266,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
f
s
.
F
ailedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
reqLog
.
Warn
(
"gateway.account_select_failed"
,
zap
.
Error
(
err
),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)))
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if
lastFailoverErr
!=
nil
&&
lastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
switchCount
<=
maxAccountSwitches
{
if
sleepAntigravitySingleAccountBackoff
(
c
.
Request
.
Context
(),
switchCount
)
{
reqLog
.
Warn
(
"gateway.single_account_retrying"
,
zap
.
Int
(
"retry_count"
,
switchCount
),
zap
.
Int
(
"max_retries"
,
maxAccountSwitches
),
)
failedAccountIDs
=
make
(
map
[
int64
]
struct
{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
action
:=
fs
.
HandleSelectionExhausted
(
c
.
Request
.
Context
())
switch
action
{
case
FailoverContinue
:
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
}
}
if
lastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
case
FailoverCanceled
:
return
default
:
// FailoverExhausted
if
fs
.
LastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
...
...
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
s
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
s
witchCount
)
if
fs
.
S
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
fs
.
S
witchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
,
hasBoundSession
)
...
...
@@ -390,46 +377,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
lastFailoverErr
=
failoverErr
if
needForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
sameAccountRetryCount
[
account
.
ID
]
<
maxSameAccountRetries
{
sameAccountRetryCount
[
account
.
ID
]
++
log
.
Printf
(
"Account %d: retryable error %d, same-account retry %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
sameAccountRetryCount
[
account
.
ID
],
maxSameAccountRetries
)
if
!
sleepSameAccountRetryDelay
(
c
.
Request
.
Context
())
{
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
case
FailoverContinue
:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if
failoverErr
.
RetryableOnSameAccount
{
h
.
gatewayService
.
TempUnscheduleRetryableError
(
c
.
Request
.
Context
(),
account
.
ID
,
failoverErr
)
}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
service
.
PlatformGemini
,
streamStarted
)
case
FailoverExhausted
:
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
return
}
switchCount
++
reqLog
.
Warn
(
"gateway.upstream_failover_switching"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
case
FailoverCanceled
:
return
}
}
continue
}
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
...
...
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription
:
subscription
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
ForceCacheBilling
:
forceCacheBilling
,
ForceCacheBilling
:
f
s
.
F
orceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
logger
.
L
()
.
With
(
...
...
@@ -486,46 +444,34 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for
{
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
// 同账号重试计数
var
lastFailoverErr
*
service
.
UpstreamFailoverError
fs
:=
NewFailoverState
(
h
.
maxAccountSwitches
,
hasBoundSession
)
retryWithFallback
:=
false
var
forceCacheBilling
bool
// 粘性会话切换时的缓存计费标记
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
f
s
.
F
ailedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
reqLog
.
Warn
(
"gateway.account_select_failed"
,
zap
.
Error
(
err
),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)))
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if
lastFailoverErr
!=
nil
&&
lastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
switchCount
<=
maxAccountSwitches
{
if
sleepAntigravitySingleAccountBackoff
(
c
.
Request
.
Context
(),
switchCount
)
{
reqLog
.
Warn
(
"gateway.single_account_retrying"
,
zap
.
Int
(
"retry_count"
,
switchCount
),
zap
.
Int
(
"max_retries"
,
maxAccountSwitches
),
)
failedAccountIDs
=
make
(
map
[
int64
]
struct
{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
action
:=
fs
.
HandleSelectionExhausted
(
c
.
Request
.
Context
())
switch
action
{
case
FailoverContinue
:
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
}
}
if
lastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverErr
,
platform
,
streamStarted
)
case
FailoverCanceled
:
return
default
:
// FailoverExhausted
if
fs
.
LastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
platform
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
...
...
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
s
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
s
witchCount
)
if
fs
.
S
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
fs
.
S
witchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
...
...
@@ -657,46 +603,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
lastFailoverErr
=
failoverErr
if
needForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
sameAccountRetryCount
[
account
.
ID
]
<
maxSameAccountRetries
{
sameAccountRetryCount
[
account
.
ID
]
++
log
.
Printf
(
"Account %d: retryable error %d, same-account retry %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
sameAccountRetryCount
[
account
.
ID
],
maxSameAccountRetries
)
if
!
sleepSameAccountRetryDelay
(
c
.
Request
.
Context
())
{
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
case
FailoverContinue
:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if
failoverErr
.
RetryableOnSameAccount
{
h
.
gatewayService
.
TempUnscheduleRetryableError
(
c
.
Request
.
Context
(),
account
.
ID
,
failoverErr
)
}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
account
.
Platform
,
streamStarted
)
case
FailoverExhausted
:
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
account
.
Platform
,
streamStarted
)
return
}
switchCount
++
reqLog
.
Warn
(
"gateway.upstream_failover_switching"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
case
FailoverCanceled
:
return
}
}
continue
}
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
...
...
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription
:
currentSubscription
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
ForceCacheBilling
:
forceCacheBilling
,
ForceCacheBilling
:
f
s
.
F
orceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
logger
.
L
()
.
With
(
...
...
@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
.
Error
(
"gateway.record_usage_failed"
,
zap
.
Error
(
err
))
}
})
reqLog
.
Debug
(
"gateway.request_completed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Bool
(
"fallback_used"
,
fallbackUsed
),
)
return
}
if
!
retryWithFallback
{
...
...
@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func
needForceCacheBilling
(
hasBoundSession
bool
,
failoverErr
*
service
.
UpstreamFailoverError
)
bool
{
return
hasBoundSession
||
(
failoverErr
!=
nil
&&
failoverErr
.
ForceCacheBilling
)
}
const
(
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries
=
2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay
=
500
*
time
.
Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func
sleepSameAccountRetryDelay
(
ctx
context
.
Context
)
bool
{
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
sameAccountRetryDelay
)
:
return
true
}
}
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func
sleepFailoverDelay
(
ctx
context
.
Context
,
switchCount
int
)
bool
{
delay
:=
time
.
Duration
(
switchCount
-
1
)
*
time
.
Second
if
delay
<=
0
{
return
true
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
delay
)
:
return
true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func
sleepAntigravitySingleAccountBackoff
(
ctx
context
.
Context
,
retryCount
int
)
bool
{
// 固定短延时:2s
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const
delay
=
2
*
time
.
Second
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.failover"
),
zap
.
Duration
(
"delay"
,
delay
),
zap
.
Int
(
"retry_count"
,
retryCount
),
)
.
Info
(
"gateway.single_account_backoff_waiting"
)
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
delay
)
:
return
true
}
}
func
(
h
*
GatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
failoverErr
*
service
.
UpstreamFailoverError
,
platform
string
,
streamStarted
bool
)
{
statusCode
:=
failoverErr
.
StatusCode
responseBody
:=
failoverErr
.
ResponseBody
...
...
backend/internal/handler/gateway_handler_single_account_retry_test.go
deleted
100644 → 0
View file @
516f8f28
package
handler
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func
TestSleepAntigravitySingleAccountBackoff_ReturnsTrue
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
1
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
,
"should return true when context is not canceled"
)
// 固定延迟 2s
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
,
"should wait approximately 2s"
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
,
"should not wait too long"
)
}
func
TestSleepAntigravitySingleAccountBackoff_ContextCanceled
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
1
)
elapsed
:=
time
.
Since
(
start
)
require
.
False
(
t
,
ok
,
"should return false when context is canceled"
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
,
"should return immediately on cancel"
)
}
func
TestSleepAntigravitySingleAccountBackoff_FixedDelay
(
t
*
testing
.
T
)
{
// 验证不同 retryCount 都使用固定 2s 延迟
ctx
:=
context
.
Background
()
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
5
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
)
// 即使 retryCount=5,延迟仍然是固定的 2s
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
)
}
backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
0 → 100644
View file @
c7e18bd5
//go:build unit
package
handler
import
(
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
type
fakeSchedulerCache
struct
{
accounts
[]
*
service
.
Account
}
func
(
f
*
fakeSchedulerCache
)
GetSnapshot
(
_
context
.
Context
,
_
service
.
SchedulerBucket
)
([]
*
service
.
Account
,
bool
,
error
)
{
return
f
.
accounts
,
true
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetSnapshot
(
_
context
.
Context
,
_
service
.
SchedulerBucket
,
_
[]
service
.
Account
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
GetAccount
(
_
context
.
Context
,
_
int64
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetAccount
(
_
context
.
Context
,
_
*
service
.
Account
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
DeleteAccount
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
UpdateLastUsed
(
_
context
.
Context
,
_
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
TryLockBucket
(
_
context
.
Context
,
_
service
.
SchedulerBucket
,
_
time
.
Duration
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeSchedulerCache
)
ListBuckets
(
_
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeSchedulerCache
)
GetOutboxWatermark
(
_
context
.
Context
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetOutboxWatermark
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
type
fakeGroupRepo
struct
{
group
*
service
.
Group
}
func
(
f
*
fakeGroupRepo
)
Create
(
context
.
Context
,
*
service
.
Group
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
GetByID
(
context
.
Context
,
int64
)
(
*
service
.
Group
,
error
)
{
return
f
.
group
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetByIDLite
(
context
.
Context
,
int64
)
(
*
service
.
Group
,
error
)
{
return
f
.
group
,
nil
}
func
(
f
*
fakeGroupRepo
)
Update
(
context
.
Context
,
*
service
.
Group
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
DeleteCascade
(
context
.
Context
,
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListActive
(
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListActiveByPlatform
(
context
.
Context
,
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeGroupRepo
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetAccountIDsByGroupIDs
(
context
.
Context
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
BindAccountsToGroup
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
UpdateSortOrders
(
context
.
Context
,
[]
service
.
GroupSortOrderUpdate
)
error
{
return
nil
}
type
fakeConcurrencyCache
struct
{}
func
(
f
*
fakeConcurrencyCache
)
AcquireAccountSlot
(
context
.
Context
,
int64
,
int
,
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
ReleaseAccountSlot
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountConcurrency
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
IncrementAccountWaitCount
(
context
.
Context
,
int64
,
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
DecrementAccountWaitCount
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountWaitingCount
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
AcquireUserSlot
(
context
.
Context
,
int64
,
int
,
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
ReleaseUserSlot
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetUserConcurrency
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
IncrementWaitCount
(
context
.
Context
,
int64
,
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
DecrementWaitCount
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountsLoadBatch
(
context
.
Context
,
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
AccountLoadInfo
{},
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetUsersLoadBatch
(
context
.
Context
,
[]
service
.
UserWithConcurrency
)
(
map
[
int64
]
*
service
.
UserLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
UserLoadInfo
{},
nil
}
func
(
f
*
fakeConcurrencyCache
)
CleanupExpiredAccountSlots
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
newTestGatewayHandler
(
t
*
testing
.
T
,
group
*
service
.
Group
,
accounts
[]
*
service
.
Account
)
(
*
GatewayHandler
,
func
())
{
t
.
Helper
()
schedulerCache
:=
&
fakeSchedulerCache
{
accounts
:
accounts
}
schedulerSnapshot
:=
service
.
NewSchedulerSnapshotService
(
schedulerCache
,
nil
,
nil
,
nil
,
nil
)
gwSvc
:=
service
.
NewGatewayService
(
nil
,
// accountRepo (not used: scheduler snapshot hit)
&
fakeGroupRepo
{
group
:
group
},
nil
,
// usageLogRepo
nil
,
// userRepo
nil
,
// userSubRepo
nil
,
// userGroupRateRepo
nil
,
// cache (disable sticky)
nil
,
// cfg
schedulerSnapshot
,
nil
,
// concurrencyService (disable load-aware; tryAcquire always acquired)
nil
,
// billingService
nil
,
// rateLimitService
nil
,
// billingCacheService
nil
,
// identityService
nil
,
// httpUpstream
nil
,
// deferredService
nil
,
// claudeTokenProvider
nil
,
// sessionLimitCache
nil
,
// digestStore
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
cfg
)
concurrencySvc
:=
service
.
NewConcurrencyService
(
&
fakeConcurrencyCache
{})
concurrencyHelper
:=
NewConcurrencyHelper
(
concurrencySvc
,
SSEPingFormatClaude
,
0
)
h
:=
&
GatewayHandler
{
gatewayService
:
gwSvc
,
billingCacheService
:
billingCacheSvc
,
concurrencyHelper
:
concurrencyHelper
,
// 这些字段对本测试不敏感,保持较小即可
maxAccountSwitches
:
1
,
maxAccountSwitchesGemini
:
1
,
}
cleanup
:=
func
()
{
billingCacheSvc
.
Stop
()
}
return
h
,
cleanup
}
func
TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
groupID
:=
int64
(
2001
)
accountID
:=
int64
(
1001
)
group
:=
&
service
.
Group
{
ID
:
groupID
,
Hydrated
:
true
,
Platform
:
service
.
PlatformAnthropic
,
// /v1/messages(Claude兼容)入口
Status
:
service
.
StatusActive
,
}
account
:=
&
service
.
Account
{
ID
:
accountID
,
Name
:
"ag-1"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"tok_xxx"
,
"intercept_warmup_requests"
:
true
,
},
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
,
// 关键:允许被 anthropic 分组混合调度选中
},
Concurrency
:
1
,
Priority
:
1
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
service
.
AccountGroup
{{
AccountID
:
accountID
,
GroupID
:
groupID
}},
}
h
,
cleanup
:=
newTestGatewayHandler
(
t
,
group
,
[]
*
service
.
Account
{
account
})
defer
cleanup
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`
)
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
group
))
c
.
Request
=
req
apiKey
:=
&
service
.
APIKey
{
ID
:
3001
,
UserID
:
4001
,
GroupID
:
&
groupID
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
4001
,
Concurrency
:
10
,
Balance
:
100
,
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
10
})
h
.
Messages
(
c
)
require
.
Equal
(
t
,
200
,
rec
.
Code
)
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
selected
,
ok
:=
c
.
Get
(
opsAccountIDKey
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
accountID
,
selected
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"msg_mock_warmup"
,
resp
[
"id"
])
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
resp
[
"model"
])
content
,
ok
:=
resp
[
"content"
]
.
([]
any
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
content
,
1
)
first
,
ok
:=
content
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"New Conversation"
,
first
[
"text"
])
}
func
TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
groupID
:=
int64
(
2002
)
accountID
:=
int64
(
1002
)
group
:=
&
service
.
Group
{
ID
:
groupID
,
Hydrated
:
true
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
}
account
:=
&
service
.
Account
{
ID
:
accountID
,
Name
:
"ag-2"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"tok_xxx"
,
"intercept_warmup_requests"
:
true
,
},
Concurrency
:
1
,
Priority
:
1
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
service
.
AccountGroup
{{
AccountID
:
accountID
,
GroupID
:
groupID
}},
}
h
,
cleanup
:=
newTestGatewayHandler
(
t
,
group
,
[]
*
service
.
Account
{
account
})
defer
cleanup
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`
)
req
:=
httptest
.
NewRequest
(
"POST"
,
"/antigravity/v1/messages"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
// - 写入 request.Context(Service读取)
// - 写入 gin.Context(Handler快速读取)
ctx
:=
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
group
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformAntigravity
)
req
=
req
.
WithContext
(
ctx
)
c
.
Request
=
req
c
.
Set
(
string
(
middleware
.
ContextKeyForcePlatform
),
service
.
PlatformAntigravity
)
apiKey
:=
&
service
.
APIKey
{
ID
:
3002
,
UserID
:
4002
,
GroupID
:
&
groupID
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
4002
,
Concurrency
:
10
,
Balance
:
100
,
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
10
})
h
.
Messages
(
c
)
require
.
Equal
(
t
,
200
,
rec
.
Code
)
selected
,
ok
:=
c
.
Get
(
opsAccountIDKey
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
accountID
,
selected
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"msg_mock_warmup"
,
resp
[
"id"
])
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
resp
[
"model"
])
}
backend/internal/handler/gemini_v1beta_handler.go
View file @
c7e18bd5
...
...
@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
cleanedForUnknownBinding
:=
false
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
var
lastFailoverErr
*
service
.
UpstreamFailoverError
var
forceCacheBilling
bool
// 粘性会话切换时的缓存计费标记
fs
:=
NewFailoverState
(
h
.
maxAccountSwitchesGemini
,
hasBoundSession
)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
...
...
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
f
s
.
F
ailedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
f
s
.
F
ailedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if
lastFailoverErr
!=
nil
&&
lastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
switchCount
<=
maxAccountSwitches
{
if
sleepAntigravitySingleAccountBackoff
(
c
.
Request
.
Context
(),
switchCount
)
{
reqLog
.
Warn
(
"gemini.single_account_retrying"
,
zap
.
Int
(
"retry_count"
,
switchCount
),
zap
.
Int
(
"max_retries"
,
maxAccountSwitches
),
)
failedAccountIDs
=
make
(
map
[
int64
]
struct
{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
action
:=
fs
.
HandleSelectionExhausted
(
c
.
Request
.
Context
())
switch
action
{
case
FailoverContinue
:
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
}
}
h
.
handleGeminiFailoverExhausted
(
c
,
lastFailoverErr
)
case
FailoverCanceled
:
return
default
:
// FailoverExhausted
h
.
handleGeminiFailoverExhausted
(
c
,
fs
.
LastFailoverErr
)
return
}
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
...
...
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
s
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
s
witchCount
)
if
fs
.
S
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
fs
.
S
witchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
modelName
,
action
,
stream
,
body
,
hasBoundSession
)
...
...
@@ -479,30 +469,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
needForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverErr
=
failoverErr
h
.
handleGeminiFailoverExhausted
(
c
,
lastFailoverErr
)
failoverAction
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
failoverAction
{
case
FailoverContinue
:
continue
case
FailoverExhausted
:
h
.
handleGeminiFailoverExhausted
(
c
,
fs
.
LastFailoverErr
)
return
}
lastFailoverErr
=
failoverErr
switchCount
++
reqLog
.
Warn
(
"gemini.upstream_failover_switching"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
switchCount
),
zap
.
Int
(
"max_switches"
,
maxAccountSwitches
),
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
case
FailoverCanceled
:
return
}
}
continue
}
// ForwardNative already wrote the response
reqLog
.
Error
(
"gemini.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
return
...
...
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress
:
clientIP
,
LongContextThreshold
:
200000
,
// Gemini 200K 阈值
LongContextMultiplier
:
2.0
,
// 超出部分双倍计费
ForceCacheBilling
:
forceCacheBilling
,
ForceCacheBilling
:
f
s
.
F
orceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
logger
.
L
()
.
With
(
...
...
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
})
reqLog
.
Debug
(
"gemini.request_completed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"switch_count"
,
s
witchCount
),
zap
.
Int
(
"switch_count"
,
fs
.
S
witchCount
),
)
return
}
...
...
backend/internal/pkg/antigravity/client_test.go
View file @
c7e18bd5
...
...
@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
// ---------------------------------------------------------------------------
func
TestClient_ExchangeCode_成功
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
// 验证请求方法
...
...
@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
}
func
TestClient_ExchangeCode_无ClientSecret
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
""
)
old
:=
defaultClientSecret
defaultClientSecret
=
""
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
client
:=
NewClient
(
""
)
_
,
err
:=
client
.
ExchangeCode
(
context
.
Background
(),
"code"
,
"verifier"
)
...
...
@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
}
func
TestClient_ExchangeCode_服务器返回错误
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
...
...
@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
// ---------------------------------------------------------------------------
func
TestClient_RefreshToken_MockServer
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
r
.
Method
!=
http
.
MethodPost
{
...
...
@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
}
func
TestClient_RefreshToken_无ClientSecret
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
""
)
old
:=
defaultClientSecret
defaultClientSecret
=
""
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
client
:=
NewClient
(
""
)
_
,
err
:=
client
.
RefreshToken
(
context
.
Background
(),
"refresh-tok"
)
...
...
@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
// ---------------------------------------------------------------------------
func
TestClient_ExchangeCode_Success_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
r
.
Method
!=
http
.
MethodPost
{
...
...
@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
}
func
TestClient_ExchangeCode_ServerError_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusBadRequest
)
...
...
@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
}
func
TestClient_ExchangeCode_InvalidJSON_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
...
...
@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
}
func
TestClient_ExchangeCode_ContextCanceled_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
time
.
Sleep
(
5
*
time
.
Second
)
// 模拟慢响应
...
...
@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
// ---------------------------------------------------------------------------
func
TestClient_RefreshToken_Success_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
r
.
Method
!=
http
.
MethodPost
{
...
...
@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
}
func
TestClient_RefreshToken_ServerError_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
WriteHeader
(
http
.
StatusUnauthorized
)
...
...
@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
}
func
TestClient_RefreshToken_InvalidJSON_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
...
...
@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
}
func
TestClient_RefreshToken_ContextCanceled_RealCall
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"test-secret"
)
old
:=
defaultClientSecret
defaultClientSecret
=
"test-secret"
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
server
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
time
.
Sleep
(
5
*
time
.
Second
)
...
...
backend/internal/pkg/antigravity/oauth.go
View file @
c7e18bd5
...
...
@@ -24,10 +24,8 @@ const (
// Antigravity OAuth 客户端凭证
ClientID
=
"1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret
=
""
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv
=
"ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri(用户需手动复制 code)
...
...
@@ -51,14 +49,21 @@ const (
antigravityDailyBaseURL
=
"https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2
var
defaultUserAgentVersion
=
"1.84.2"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
var
defaultUserAgentVersion
=
"1.18.4"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var
defaultClientSecret
=
"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"
func
init
()
{
// 从环境变量读取版本号,未设置则使用默认值
if
version
:=
os
.
Getenv
(
"ANTIGRAVITY_USER_AGENT_VERSION"
);
version
!=
""
{
defaultUserAgentVersion
=
version
}
// 从环境变量读取 client_secret,未设置则使用默认值
if
secret
:=
os
.
Getenv
(
AntigravityOAuthClientSecretEnv
);
secret
!=
""
{
defaultClientSecret
=
secret
}
}
// GetUserAgent 返回当前配置的 User-Agent
...
...
@@ -67,14 +72,9 @@ func GetUserAgent() string {
}
func
getClientSecret
()
(
string
,
error
)
{
if
v
:=
strings
.
TrimSpace
(
ClientSecret
);
v
!=
""
{
if
v
:=
strings
.
TrimSpace
(
default
ClientSecret
);
v
!=
""
{
return
v
,
nil
}
if
v
,
ok
:=
os
.
LookupEnv
(
AntigravityOAuthClientSecretEnv
);
ok
{
if
vv
:=
strings
.
TrimSpace
(
v
);
vv
!=
""
{
return
vv
,
nil
}
}
return
""
,
infraerrors
.
Newf
(
http
.
StatusBadRequest
,
"ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING"
,
"missing antigravity oauth client_secret; set %s"
,
AntigravityOAuthClientSecretEnv
)
}
...
...
backend/internal/pkg/antigravity/oauth_test.go
View file @
c7e18bd5
...
...
@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/hex"
"net/url"
"os"
"strings"
"testing"
"time"
...
...
@@ -17,8 +18,14 @@ import (
// ---------------------------------------------------------------------------
func
TestGetClientSecret_环境变量设置
(
t
*
testing
.
T
)
{
old
:=
defaultClientSecret
defaultClientSecret
=
""
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
"my-secret-value"
)
// 需要重新触发 init 逻辑:手动从环境变量读取
defaultClientSecret
=
os
.
Getenv
(
AntigravityOAuthClientSecretEnv
)
secret
,
err
:=
getClientSecret
()
if
err
!=
nil
{
t
.
Fatalf
(
"获取 client_secret 失败: %v"
,
err
)
...
...
@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
}
func
TestGetClientSecret_环境变量为空
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
""
)
old
:=
defaultClientSecret
defaultClientSecret
=
""
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
_
,
err
:=
getClientSecret
()
if
err
==
nil
{
t
.
Fatal
(
"
环境变量
为空时应返回错误"
)
t
.
Fatal
(
"
defaultClientSecret
为空时应返回错误"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
AntigravityOAuthClientSecretEnv
)
{
t
.
Errorf
(
"错误信息应包含环境变量名: got %s"
,
err
.
Error
())
...
...
@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
}
func
TestGetClientSecret_环境变量未设置
(
t
*
testing
.
T
)
{
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
// 明确设置再取消,确保环境变量不存在
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
""
)
old
:=
defaultClientSecret
defaultClientSecret
=
""
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
_
,
err
:=
getClientSecret
()
if
err
==
nil
{
t
.
Fatal
(
"
环境变量未设置
时应返回错误"
)
t
.
Fatal
(
"
defaultClientSecret 为空
时应返回错误"
)
}
}
func
TestGetClientSecret_环境变量含空格
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
" "
)
old
:=
defaultClientSecret
defaultClientSecret
=
" "
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
_
,
err
:=
getClientSecret
()
if
err
==
nil
{
t
.
Fatal
(
"
环境变量
仅含空格时应返回错误"
)
t
.
Fatal
(
"
defaultClientSecret
仅含空格时应返回错误"
)
}
}
func
TestGetClientSecret_环境变量有前后空格
(
t
*
testing
.
T
)
{
t
.
Setenv
(
AntigravityOAuthClientSecretEnv
,
" valid-secret "
)
old
:=
defaultClientSecret
defaultClientSecret
=
" valid-secret "
t
.
Cleanup
(
func
()
{
defaultClientSecret
=
old
})
secret
,
err
:=
getClientSecret
()
if
err
!=
nil
{
...
...
@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
if
ClientID
!=
"1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
{
t
.
Errorf
(
"ClientID 不匹配: got %s"
,
ClientID
)
}
if
ClientSecret
!=
""
{
t
.
Error
(
"ClientSecret 应为空字符串"
)
secret
,
err
:=
getClientSecret
()
if
err
!=
nil
{
t
.
Fatalf
(
"getClientSecret 应返回默认值,但报错: %v"
,
err
)
}
if
secret
!=
"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"
{
t
.
Errorf
(
"默认 client_secret 不匹配: got %s"
,
secret
)
}
if
RedirectURI
!=
"http://localhost:8085/callback"
{
t
.
Errorf
(
"RedirectURI 不匹配: got %s"
,
RedirectURI
)
}
if
GetUserAgent
()
!=
"antigravity/1.
84.2
windows/amd64"
{
if
GetUserAgent
()
!=
"antigravity/1.
18.4
windows/amd64"
{
t
.
Errorf
(
"UserAgent 不匹配: got %s"
,
GetUserAgent
())
}
if
SessionTTL
!=
30
*
time
.
Minute
{
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
c7e18bd5
...
...
@@ -206,6 +206,7 @@ type modelInfo struct {
var
modelInfoMap
=
map
[
string
]
modelInfo
{
"claude-opus-4-5"
:
{
DisplayName
:
"Claude Opus 4.5"
,
CanonicalID
:
"claude-opus-4-5-20250929"
},
"claude-opus-4-6"
:
{
DisplayName
:
"Claude Opus 4.6"
,
CanonicalID
:
"claude-opus-4-6"
},
"claude-sonnet-4-6"
:
{
DisplayName
:
"Claude Sonnet 4.6"
,
CanonicalID
:
"claude-sonnet-4-6"
},
"claude-sonnet-4-5"
:
{
DisplayName
:
"Claude Sonnet 4.5"
,
CanonicalID
:
"claude-sonnet-4-5-20250929"
},
"claude-haiku-4-5"
:
{
DisplayName
:
"Claude Haiku 4.5"
,
CanonicalID
:
"claude-haiku-4-5-20251001"
},
}
...
...
backend/internal/server/routes/admin.go
View file @
c7e18bd5
...
...
@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
.
GET
(
""
,
h
.
Admin
.
Account
.
List
)
accounts
.
GET
(
"/:id"
,
h
.
Admin
.
Account
.
GetByID
)
accounts
.
POST
(
""
,
h
.
Admin
.
Account
.
Create
)
accounts
.
POST
(
"/check-mixed-channel"
,
h
.
Admin
.
Account
.
CheckMixedChannel
)
accounts
.
POST
(
"/sync/crs"
,
h
.
Admin
.
Account
.
SyncFromCRS
)
accounts
.
POST
(
"/sync/crs/preview"
,
h
.
Admin
.
Account
.
PreviewFromCRS
)
accounts
.
PUT
(
"/:id"
,
h
.
Admin
.
Account
.
Update
)
...
...
backend/internal/service/account_intercept_warmup_test.go
0 → 100644
View file @
c7e18bd5
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestAccount_IsInterceptWarmupEnabled
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
expected
bool
}{
{
name
:
"nil credentials"
,
credentials
:
nil
,
expected
:
false
,
},
{
name
:
"empty map"
,
credentials
:
map
[
string
]
any
{},
expected
:
false
,
},
{
name
:
"field not present"
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"tok"
},
expected
:
false
,
},
{
name
:
"field is true"
,
credentials
:
map
[
string
]
any
{
"intercept_warmup_requests"
:
true
},
expected
:
true
,
},
{
name
:
"field is false"
,
credentials
:
map
[
string
]
any
{
"intercept_warmup_requests"
:
false
},
expected
:
false
,
},
{
name
:
"field is string true"
,
credentials
:
map
[
string
]
any
{
"intercept_warmup_requests"
:
"true"
},
expected
:
false
,
},
{
name
:
"field is int 1"
,
credentials
:
map
[
string
]
any
{
"intercept_warmup_requests"
:
1
},
expected
:
false
,
},
{
name
:
"field is nil"
,
credentials
:
map
[
string
]
any
{
"intercept_warmup_requests"
:
nil
},
expected
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Credentials
:
tt
.
credentials
}
result
:=
a
.
IsInterceptWarmupEnabled
()
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
backend/internal/service/admin_service.go
View file @
c7e18bd5
...
...
@@ -54,6 +54,7 @@ type AdminService interface {
SetAccountError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
SetAccountSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
(
*
Account
,
error
)
BulkUpdateAccounts
(
ctx
context
.
Context
,
input
*
BulkUpdateAccountsInput
)
(
*
BulkUpdateAccountsResult
,
error
)
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
// Proxy management
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
Proxy
,
int64
,
error
)
...
...
@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return
nil
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func
(
s
*
adminServiceImpl
)
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
return
s
.
checkMixedChannelRisk
(
ctx
,
currentAccountID
,
currentAccountPlatform
,
groupIDs
)
}
func
(
s
*
adminServiceImpl
)
attachProxyLatency
(
ctx
context
.
Context
,
proxies
[]
ProxyWithAccountCount
)
{
if
s
.
proxyLatencyCache
==
nil
||
len
(
proxies
)
==
0
{
return
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
c7e18bd5
...
...
@@ -87,7 +87,6 @@ var (
)
const
(
antigravityBillingModelEnv
=
"GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv
=
"GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv
=
"GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
...
...
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
(
claudeReq
.
Thinking
.
Type
==
"enabled"
||
claudeReq
.
Thinking
.
Type
==
"adaptive"
)
mappedModel
=
applyThinkingModelSuffix
(
mappedModel
,
thinkingEnabled
)
billingModel
:=
mappedModel
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
ForceCacheBilling
:
switchErr
.
IsStickySession
,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if
c
.
Request
.
Context
()
.
Err
()
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"client_disconnected"
,
"Client disconnected before upstream response"
)
}
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
}
resp
:=
result
.
resp
...
...
@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
// 使用
原始
模型用于计费和日志
Model
:
billing
Model
,
// 使用
映射
模型用于计费和日志
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
mappedModel
==
""
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusForbidden
,
fmt
.
Sprintf
(
"model %s not in whitelist"
,
originalModel
))
}
billingModel
:=
mappedModel
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
ForceCacheBilling
:
switchErr
.
IsStickySession
,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if
c
.
Request
.
Context
()
.
Err
()
!=
nil
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Client disconnected before upstream response"
)
}
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
}
resp
:=
result
.
resp
...
...
@@ -2197,7 +2206,7 @@ handleSuccess:
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
original
Model
,
Model
:
billing
Model
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
...
...
@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
defaultDur
:=
s
.
getDefaultRateLimitDuration
()
// 尝试解析模型 key 并设置模型级限流
modelKey
:=
resolveAntigravityModelKey
(
requestedModel
)
//
// 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6),
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
modelKey
:=
resolveFinalAntigravityModelKey
(
ctx
,
account
,
requestedModel
)
if
strings
.
TrimSpace
(
modelKey
)
==
""
{
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
modelKey
=
resolveAntigravityModelKey
(
requestedModel
)
}
if
modelKey
!=
""
{
ra
:=
s
.
resolveResetTime
(
resetAt
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetModelRateLimit
(
ctx
,
account
.
ID
,
modelKey
,
ra
);
err
!=
nil
{
...
...
@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
...
...
@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billing
Model
,
Model
:
original
Model
,
},
nil
}
...
...
@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
logger
.
LegacyPrintf
(
"service.antigravity_gateway"
,
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billing
Model
,
Model
:
original
Model
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
...
...
backend/internal/service/antigravity_gateway_service_test.go
View file @
c7e18bd5
...
...
@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return
s
.
resp
,
s
.
err
}
type
antigravitySettingRepoStub
struct
{}
func
(
s
*
antigravitySettingRepoStub
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
panic
(
"unexpected Get call"
)
}
func
(
s
*
antigravitySettingRepoStub
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
return
""
,
ErrSettingNotFound
}
func
(
s
*
antigravitySettingRepoStub
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
panic
(
"unexpected Set call"
)
}
func
(
s
*
antigravitySettingRepoStub
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetMultiple call"
)
}
func
(
s
*
antigravitySettingRepoStub
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
panic
(
"unexpected SetMultiple call"
)
}
func
(
s
*
antigravitySettingRepoStub
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetAll call"
)
}
func
(
s
*
antigravitySettingRepoStub
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
panic
(
"unexpected Delete call"
)
}
func
TestAntigravityGatewayService_Forward_PromptTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
...
...
@@ -160,6 +190,7 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
}
svc
:=
&
AntigravityGatewayService
{
settingService
:
NewSettingService
(
&
antigravitySettingRepoStub
{},
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
}}),
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
...
...
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
func
TestAntigravityGatewayService_Forward_BillsWithMappedModel
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
body
,
err
:=
json
.
Marshal
(
map
[
string
]
any
{
"model"
:
"claude-sonnet-4-5"
,
"messages"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"content"
:
"hello"
},
},
"max_tokens"
:
16
,
"stream"
:
true
,
})
require
.
NoError
(
t
,
err
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
=
req
upstreamBody
:=
[]
byte
(
"data: {
\"
response
\"
:{
\"
candidates
\"
:[{
\"
content
\"
:{
\"
parts
\"
:[{
\"
text
\"
:
\"
ok
\"
}]},
\"
finishReason
\"
:
\"
STOP
\"
}],
\"
usageMetadata
\"
:{
\"
promptTokenCount
\"
:8,
\"
candidatesTokenCount
\"
:3}}}
\n\n
"
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"X-Request-Id"
:
[]
string
{
"req-bill-1"
}},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
upstreamBody
)),
}
svc
:=
&
AntigravityGatewayService
{
settingService
:
NewSettingService
(
&
antigravitySettingRepoStub
{},
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
}}),
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
const
mappedModel
=
"gemini-3-pro-high"
account
:=
&
Account
{
ID
:
5
,
Name
:
"acc-forward-billing"
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
"model_mapping"
:
map
[
string
]
any
{
"claude-sonnet-4-5"
:
mappedModel
,
},
},
}
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
mappedModel
,
result
.
Model
)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
func
TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
body
,
err
:=
json
.
Marshal
(
map
[
string
]
any
{
"contents"
:
[]
map
[
string
]
any
{
{
"role"
:
"user"
,
"parts"
:
[]
map
[
string
]
any
{{
"text"
:
"hello"
}}},
},
})
require
.
NoError
(
t
,
err
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1beta/models/gemini-2.5-flash:generateContent"
,
bytes
.
NewReader
(
body
))
c
.
Request
=
req
upstreamBody
:=
[]
byte
(
"data: {
\"
response
\"
:{
\"
candidates
\"
:[{
\"
content
\"
:{
\"
parts
\"
:[{
\"
text
\"
:
\"
ok
\"
}]},
\"
finishReason
\"
:
\"
STOP
\"
}],
\"
usageMetadata
\"
:{
\"
promptTokenCount
\"
:8,
\"
candidatesTokenCount
\"
:3}}}
\n\n
"
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"X-Request-Id"
:
[]
string
{
"req-bill-2"
}},
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
upstreamBody
)),
}
svc
:=
&
AntigravityGatewayService
{
settingService
:
NewSettingService
(
&
antigravitySettingRepoStub
{},
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
}}),
tokenProvider
:
&
AntigravityTokenProvider
{},
httpUpstream
:
&
httpUpstreamStub
{
resp
:
resp
},
}
const
mappedModel
=
"gemini-3-pro-high"
account
:=
&
Account
{
ID
:
6
,
Name
:
"acc-gemini-billing"
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
"model_mapping"
:
map
[
string
]
any
{
"gemini-2.5-flash"
:
mappedModel
,
},
},
}
result
,
err
:=
svc
.
ForwardGemini
(
context
.
Background
(),
c
,
account
,
"gemini-2.5-flash"
,
"generateContent"
,
true
,
body
,
false
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
mappedModel
,
result
.
Model
)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func
TestStreamUpstreamResponse_UsageAndFirstToken
(
t
*
testing
.
T
)
{
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment