Unverified Commit c7e18bd5 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #627 from touwaeriol/pr/bugfixes-and-enhancements

feat: 反重力(Antigravity)增强、Failover 重构及新模型支持
parents 516f8f28 8365a832
...@@ -1158,6 +1158,7 @@ func setDefaults() { ...@@ -1158,6 +1158,7 @@ func setDefaults() {
viper.SetDefault("gateway.force_codex_cli", false) viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
......
...@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{ ...@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5", "claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
// Claude 详细版本 ID 映射 // Claude 详细版本 ID 映射
...@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{ ...@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-pro": "gemini-2.5-pro", "gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单 // Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash", "gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3.1-pro-high", "gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3.1-pro-low", "gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image", "gemini-3-pro-image": "gemini-3-pro-image",
// Gemini 3.1 透传
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3 preview 映射 // Gemini 3 preview 映射
"gemini-3-flash-preview": "gemini-3-flash", "gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3.1-pro-high", "gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image", "gemini-3-pro-image-preview": "gemini-3-pro-image",
// Gemini 3.1 白名单
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
// 其他官方模型 // 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium", "gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview", "tab_flash_lite_preview": "tab_flash_lite_preview",
......
...@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct { ...@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
} }
// CheckMixedChannelRequest is the JSON body accepted by the mixed-channel
// risk check endpoint.
type CheckMixedChannelRequest struct {
// Platform is the platform of the account being checked; the binding tag
// makes it mandatory.
Platform string `json:"platform" binding:"required"`
// GroupIDs lists the groups the account would be bound to.
GroupIDs []int64 `json:"group_ids"`
// AccountID optionally identifies an existing account; it is nil when the
// field is omitted from the request.
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info // AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct { type AccountWithConcurrency struct {
*dto.Account *dto.Account
...@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) { ...@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// CheckMixedChannel reports whether binding an account to the requested groups
// would mix channels.
// POST /api/v1/admin/accounts/check-mixed-channel
//
// The outcome is always delivered through response.Success with a "has_risk"
// flag; when the service reports a *service.MixedChannelError the payload also
// carries the warning code, message and conflict details. Any other service
// error is forwarded via response.ErrorFrom.
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
	var req CheckMixedChannelRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		response.BadRequest(c, "Invalid request: "+bindErr.Error())
		return
	}

	// Nothing to bind to means nothing can conflict.
	if len(req.GroupIDs) == 0 {
		response.Success(c, gin.H{"has_risk": false})
		return
	}

	// A missing account_id is passed to the service as 0.
	var accountID int64
	if id := req.AccountID; id != nil {
		accountID = *id
	}

	checkErr := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
	if checkErr == nil {
		response.Success(c, gin.H{"has_risk": false})
		return
	}

	var mixedErr *service.MixedChannelError
	if !errors.As(checkErr, &mixedErr) {
		response.ErrorFrom(c, checkErr)
		return
	}

	conflict := gin.H{
		"group_id":         mixedErr.GroupID,
		"group_name":       mixedErr.GroupName,
		"current_platform": mixedErr.CurrentPlatform,
		"other_platform":   mixedErr.OtherPlatform,
	}
	response.Success(c, gin.H{
		"has_risk": true,
		"error":    "mixed_channel_warning",
		"message":  mixedErr.Error(),
		"details":  conflict,
	})
}
// Create handles creating a new account // Create handles creating a new account
// POST /api/v1/admin/accounts // POST /api/v1/admin/accounts
func (h *AccountHandler) Create(c *gin.Context) { func (h *AccountHandler) Create(c *gin.Context) {
...@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 检查是否为混合渠道错误 // 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) { if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认 // 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{ c.JSON(409, gin.H{
"error": "mixed_channel_warning", "error": "mixed_channel_warning",
"message": mixedErr.Error(), "message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
}) })
return return
} }
...@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) { ...@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误 // 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) { if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认 // 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{ c.JSON(409, gin.H{
"error": "mixed_channel_warning", "error": "mixed_channel_warning",
"message": mixedErr.Error(), "message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
}) })
return return
} }
......
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// setupAccountMixedChannelRouter returns a test-mode gin engine that exposes
// only the account routes exercised by the mixed-channel tests, with the
// account handler backed by adminSvc and every other dependency left nil.
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
	const accountsPath = "/api/v1/admin/accounts"

	gin.SetMode(gin.TestMode)
	engine := gin.New()
	h := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)

	engine.POST(accountsPath+"/check-mixed-channel", h.CheckMixedChannel)
	engine.POST(accountsPath, h.Create)
	engine.PUT(accountsPath+"/:id", h.Update)
	return engine
}
// TestAccountHandlerCheckMixedChannelNoRisk verifies that the check endpoint
// answers has_risk=false when the service reports no conflict, and that the
// request fields reach the service unchanged (a missing account_id becomes 0).
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
	adminSvc := newStubAdminService()
	router := setupAccountMixedChannelRouter(adminSvc)

	body, err := json.Marshal(map[string]any{
		"platform":  "antigravity",
		"group_ids": []int64{27},
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	// JSON numbers decode into float64 when the target is map[string]any.
	require.Equal(t, float64(0), resp["code"])
	data, ok := resp["data"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, false, data["has_risk"])

	require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
	require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
	require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
// TestAccountHandlerCheckMixedChannelWithRisk verifies that a
// *service.MixedChannelError from the service is reported as a 200 response
// with has_risk=true plus the conflict details, and that account_id is
// forwarded to the service.
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
	adminSvc := newStubAdminService()
	adminSvc.checkMixedErr = &service.MixedChannelError{
		GroupID:         27,
		GroupName:       "claude-max",
		CurrentPlatform: "Antigravity",
		OtherPlatform:   "Anthropic",
	}
	router := setupAccountMixedChannelRouter(adminSvc)

	body, err := json.Marshal(map[string]any{
		"platform":   "antigravity",
		"group_ids":  []int64{27},
		"account_id": 99,
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	// JSON numbers decode into float64 when the target is map[string]any.
	require.Equal(t, float64(0), resp["code"])
	data, ok := resp["data"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, true, data["has_risk"])
	require.Equal(t, "mixed_channel_warning", data["error"])

	details, ok := data["details"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(27), details["group_id"])
	require.Equal(t, "claude-max", details["group_name"])
	require.Equal(t, "Antigravity", details["current_platform"])
	require.Equal(t, "Anthropic", details["other_platform"])

	require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
// TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse verifies that
// a mixed-channel conflict on account creation yields HTTP 409 with only the
// "error" and "message" fields, without "details" or "require_confirmation".
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
	adminSvc := newStubAdminService()
	adminSvc.createAccountErr = &service.MixedChannelError{
		GroupID:         27,
		GroupName:       "claude-max",
		CurrentPlatform: "Antigravity",
		OtherPlatform:   "Anthropic",
	}
	router := setupAccountMixedChannelRouter(adminSvc)

	body, err := json.Marshal(map[string]any{
		"name":        "ag-oauth-1",
		"platform":    "antigravity",
		"type":        "oauth",
		"credentials": map[string]any{"refresh_token": "rt"},
		"group_ids":   []int64{27},
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusConflict, rec.Code)

	var resp map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	require.Equal(t, "mixed_channel_warning", resp["error"])
	require.Contains(t, resp["message"], "mixed_channel_warning")

	_, hasDetails := resp["details"]
	_, hasRequireConfirmation := resp["require_confirmation"]
	require.False(t, hasDetails)
	require.False(t, hasRequireConfirmation)
}
// TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse verifies that
// a mixed-channel conflict on account update yields HTTP 409 with only the
// "error" and "message" fields, without "details" or "require_confirmation".
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
	adminSvc := newStubAdminService()
	adminSvc.updateAccountErr = &service.MixedChannelError{
		GroupID:         27,
		GroupName:       "claude-max",
		CurrentPlatform: "Antigravity",
		OtherPlatform:   "Anthropic",
	}
	router := setupAccountMixedChannelRouter(adminSvc)

	body, err := json.Marshal(map[string]any{
		"group_ids": []int64{27},
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)

	require.Equal(t, http.StatusConflict, rec.Code)

	var resp map[string]any
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	require.Equal(t, "mixed_channel_warning", resp["error"])
	require.Contains(t, resp["message"], "mixed_channel_warning")

	_, hasDetails := resp["details"]
	_, hasRequireConfirmation := resp["require_confirmation"]
	require.False(t, hasDetails)
	require.False(t, hasRequireConfirmation)
}
...@@ -10,19 +10,27 @@ import ( ...@@ -10,19 +10,27 @@ import (
) )
type stubAdminService struct { type stubAdminService struct {
users []service.User users []service.User
apiKeys []service.APIKey apiKeys []service.APIKey
groups []service.Group groups []service.Group
accounts []service.Account accounts []service.Account
proxies []service.Proxy proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64 updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64 testedProxyIDs []int64
mu sync.Mutex createAccountErr error
updateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
} }
func newStubAdminService() *stubAdminService { func newStubAdminService() *stubAdminService {
...@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre ...@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock() s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input) s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock() s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil return &account, nil
} }
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive} account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil return &account, nil
} }
...@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic ...@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
} }
// CheckMixedChannelRisk records the arguments of the most recent call in
// lastMixedCheck and returns the configured checkMixedErr.
//
// The recording is guarded by s.mu, matching how CreateAccount protects the
// state it records, so the stub stays safe if a handler is ever driven from
// several goroutines. groupIDs is copied so that later mutation by the caller
// cannot alter what was recorded.
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.lastMixedCheck.accountID = currentAccountID
	s.lastMixedCheck.platform = currentAccountPlatform
	s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
	return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
search = strings.TrimSpace(strings.ToLower(search)) search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies)) filtered := make([]service.Proxy, 0, len(s.proxies))
......
package handler
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TempUnscheduler is the hook HandleFailoverError uses to temporarily
// unschedule an account once its same-account retries are exhausted.
// GatewayService satisfies this interface implicitly.
type TempUnscheduler interface {
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
}
// FailoverAction is the next step a caller should take after a failover
// error has been handled.
type FailoverAction int
const (
// FailoverContinue means keep looping (same-account retry or account
// switch; the caller simply continues in both cases).
FailoverContinue FailoverAction = iota
// FailoverExhausted means the switch budget is used up (the caller should
// return an error response).
FailoverExhausted
// FailoverCanceled means the context was canceled (the caller should
// return immediately).
FailoverCanceled
)
const (
// maxSameAccountRetries caps the number of retries on the same account
// (applies to RetryableOnSameAccount errors).
maxSameAccountRetries = 2
// sameAccountRetryDelay is the pause between same-account retries.
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay is the fixed backoff before retrying a 503 in
// a single-account group.
// Per the original author's note, the service layer already retries in
// place in SingleAccountRetry mode (reportedly up to 3 attempts, about 30s
// total; not verifiable from this file), so the handler layer only needs
// a short pause before re-entering the service layer.
singleAccountBackoffDelay = 2 * time.Second
)
// FailoverState holds the failover bookkeeping shared across iterations of a
// request's retry loop.
type FailoverState struct {
// SwitchCount is the number of account switches performed so far.
SwitchCount int
// MaxSwitches is the switch budget; reaching it yields FailoverExhausted.
MaxSwitches int
// FailedAccountIDs is the set of account IDs that have failed during the
// current loop.
FailedAccountIDs map[int64]struct{}
// SameAccountRetryCount tracks same-account retries per account ID.
SameAccountRetryCount map[int64]int
// LastFailoverErr is the most recent failover error that was handled.
LastFailoverErr *service.UpstreamFailoverError
// ForceCacheBilling is set once any handled error requires forced cache
// billing; it is never cleared by the methods in this file.
ForceCacheBilling bool
// hasBoundSession records whether the request had a bound session when the
// state was created.
hasBoundSession bool
}
// NewFailoverState returns a FailoverState with the given switch budget and
// bound-session flag, and with both bookkeeping maps allocated and empty.
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
	state := new(FailoverState)
	state.MaxSwitches = maxSwitches
	state.hasBoundSession = hasBoundSession
	state.FailedAccountIDs = map[int64]struct{}{}
	state.SameAccountRetryCount = map[int64]int{}
	return state
}
// HandleFailoverError processes an UpstreamFailoverError and returns the next
// action for the caller's retry loop. It covers the forced-cache-billing
// decision, same-account retries, temporary unscheduling, switch counting and
// the Antigravity switch delay.
//
// failoverErr must be non-nil; it is dereferenced unconditionally.
//
// Returns FailoverContinue when the caller should loop again (either retrying
// the same account or selecting another one), FailoverExhausted when
// SwitchCount had already reached MaxSwitches, and FailoverCanceled when ctx
// was canceled during one of the waits.
func (s *FailoverState) HandleFailoverError(
	ctx context.Context,
	gatewayService TempUnscheduler,
	accountID int64,
	platform string,
	failoverErr *service.UpstreamFailoverError,
) FailoverAction {
	s.LastFailoverErr = failoverErr

	// The billing flag is sticky: once set it stays set for the whole request.
	if needForceCacheBilling(s.hasBoundSession, failoverErr) {
		s.ForceCacheBilling = true
	}

	if failoverErr.RetryableOnSameAccount {
		// Transient error: retry on the same account while its budget lasts.
		if used := s.SameAccountRetryCount[accountID]; used < maxSameAccountRetries {
			attempt := used + 1
			s.SameAccountRetryCount[accountID] = attempt
			log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
				accountID, failoverErr.StatusCode, attempt, maxSameAccountRetries)
			if sleepWithContext(ctx, sameAccountRetryDelay) {
				return FailoverContinue
			}
			return FailoverCanceled
		}
		// Same-account retries are used up: temporarily unschedule the account.
		gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
	}

	// Mark the account as failed before checking the budget, so it is recorded
	// even when the budget is already exhausted.
	s.FailedAccountIDs[accountID] = struct{}{}

	if s.SwitchCount >= s.MaxSwitches {
		return FailoverExhausted
	}

	s.SwitchCount++
	log.Printf("Account %d: upstream error %d, switching account %d/%d",
		accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)

	if platform != service.PlatformAntigravity {
		return FailoverContinue
	}

	// Antigravity switches wait (SwitchCount-1) seconds, so the first switch
	// is immediate and each later one waits a second longer.
	wait := time.Duration(s.SwitchCount-1) * time.Second
	if !sleepWithContext(ctx, wait) {
		return FailoverCanceled
	}
	return FailoverContinue
}
// HandleSelectionExhausted decides whether to back off and retry when account
// selection fails because every candidate is in the exclusion list. It targets
// the Antigravity single-account-group 503 (MODEL_CAPACITY_EXHAUSTED) case:
// wait for the backoff, clear the exclusion list, then let the caller select
// again. The check itself looks only at the last error's status code and the
// switch budget, not at the platform.
//
// On FailoverContinue the caller should set the SingleAccountRetry context
// value and continue.
// On FailoverExhausted the caller should return an error response.
// On FailoverCanceled the caller should return immediately.
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
	last := s.LastFailoverErr
	if last == nil || last.StatusCode != http.StatusServiceUnavailable || s.SwitchCount > s.MaxSwitches {
		return FailoverExhausted
	}

	log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
		singleAccountBackoffDelay, s.SwitchCount)
	if !sleepWithContext(ctx, singleAccountBackoffDelay) {
		return FailoverCanceled
	}

	log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
		s.SwitchCount, s.MaxSwitches)
	// A fresh map makes previously failed accounts selectable again.
	s.FailedAccountIDs = make(map[int64]struct{})
	return FailoverContinue
}
// needForceCacheBilling reports whether a failover must force cache billing:
// true when the request has a bound (sticky) session, or when a non-nil
// failoverErr carries ForceCacheBilling.
// Per the original author's note, forced cache billing charges input_tokens
// as cache_read; that conversion is not visible in this file.
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
	if hasBoundSession {
		return true
	}
	if failoverErr == nil {
		return false
	}
	return failoverErr.ForceCacheBilling
}
// sleepWithContext blocks for d or until ctx is done, whichever comes first.
// It returns true when the full duration elapsed and false when ctx was
// canceled first. A non-positive d returns true immediately without
// consulting ctx.
func sleepWithContext(ctx context.Context, d time.Duration) bool {
	if d <= 0 {
		return true
	}
	// An explicit timer (rather than time.After) lets us release it as soon as
	// ctx wins the race, instead of leaving it pending for the rest of d.
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
This diff is collapsed.
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"strings" "strings"
"time" "time"
...@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
...@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
// Antigravity 单账号退避重试:分组内没有其他可用账号时, action := fs.HandleSelectionExhausted(c.Request.Context())
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 switch action {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 case FailoverContinue:
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { c.Request = c.Request.WithContext(ctx)
reqLog.Warn("gateway.single_account_retrying", continue
zap.Int("retry_count", switchCount), case FailoverCanceled:
zap.Int("max_retries", maxAccountSwitches), return
) default: // FailoverExhausted
failedAccountIDs = make(map[int64]struct{}) if fs.LastFailoverErr != nil {
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) } else {
c.Request = c.Request.WithContext(ctx) h.handleFailoverExhaustedSimple(c, 502, streamStarted)
continue
} }
return
} }
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform) setOpsSelectedAccount(c, account.ID, account.Platform)
...@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if switchCount > 0 { if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
} }
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
...@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
if needForceCacheBilling(hasBoundSession, failoverErr) { switch action {
forceCacheBilling = true case FailoverContinue:
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue continue
} case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
// 同账号重试用尽,执行临时封禁并切换账号 return
if failoverErr.RetryableOnSameAccount { case FailoverCanceled:
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return return
} }
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
} }
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed", reqLog.Error("gateway.forward_failed",
...@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: subscription, Subscription: subscription,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
...@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
for { for {
maxAccountSwitches := h.maxAccountSwitches fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
// Antigravity 单账号退避重试:分组内没有其他可用账号时, action := fs.HandleSelectionExhausted(c.Request.Context())
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 switch action {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 case FailoverContinue:
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { c.Request = c.Request.WithContext(ctx)
reqLog.Warn("gateway.single_account_retrying", continue
zap.Int("retry_count", switchCount), case FailoverCanceled:
zap.Int("max_retries", maxAccountSwitches), return
) default: // FailoverExhausted
failedAccountIDs = make(map[int64]struct{}) if fs.LastFailoverErr != nil {
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) } else {
c.Request = c.Request.WithContext(ctx) h.handleFailoverExhaustedSimple(c, 502, streamStarted)
continue
} }
return
} }
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform) setOpsSelectedAccount(c, account.ID, account.Platform)
...@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if switchCount > 0 { if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
} }
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
...@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
if needForceCacheBilling(hasBoundSession, failoverErr) { switch action {
forceCacheBilling = true case FailoverContinue:
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue continue
} case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
// 同账号重试用尽,执行临时封禁并切换账号 return
if failoverErr.RetryableOnSameAccount { case FailoverCanceled:
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return return
} }
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
} }
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed", reqLog.Error("gateway.forward_failed",
...@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: currentSubscription, Subscription: currentSubscription,
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
...@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
).Error("gateway.record_usage_failed", zap.Error(err)) ).Error("gateway.record_usage_failed", zap.Error(err))
} }
}) })
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Bool("fallback_used", fallbackUsed),
)
return return
} }
if !retryWithFallback { if !retryWithFallback {
...@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT ...@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
// needForceCacheBilling reports whether cache billing must be forced for a
// request that went through failover.
//
// It returns true when the request has a bound (sticky) session, or when the
// upstream failover error is non-nil and has its ForceCacheBilling flag set.
// Per the original note, forced cache billing means input_tokens are billed
// as cache_read.
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
	if hasBoundSession {
		return true
	}
	if failoverErr == nil {
		return false
	}
	return failoverErr.ForceCacheBilling
}
const (
	// maxSameAccountRetries is the upper bound on same-account retries
	// (applies to RetryableOnSameAccount errors).
	maxSameAccountRetries = 2
	// sameAccountRetryDelay is the pause between two same-account retries.
	sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay waits for the fixed same-account retry delay
// (sameAccountRetryDelay). It returns true once the delay has elapsed and
// false if ctx is done first.
func sleepSameAccountRetryDelay(ctx context.Context) bool {
	// An explicit timer that is stopped on return releases its resources as
	// soon as ctx is canceled, instead of lingering until it fires as a
	// time.After timer does on Go toolchains before 1.23.
	timer := time.NewTimer(sameAccountRetryDelay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
// sleepFailoverDelay applies a linearly increasing delay before switching
// accounts: (switchCount-1) seconds, i.e. 0s for the first switch, 1s for the
// second, 2s for the third, and so on.
//
// It returns true once the delay has elapsed (immediately when the computed
// delay is not positive, without consulting ctx) and false if ctx is done
// before the delay elapses.
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
	delay := time.Duration(switchCount-1) * time.Second
	if delay <= 0 {
		return true
	}

	// An explicit timer that is stopped on return releases its resources as
	// soon as ctx is canceled, instead of lingering until it fires as a
	// time.After timer does on Go toolchains before 1.23.
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
// sleepAntigravitySingleAccountBackoff is the backoff used for 503 retries on
// Antigravity groups that contain a single account.
//
// Per the original note, it is used when the only available account in the
// group gets an upstream 503 (MODEL_CAPACITY_EXHAUSTED). The note also states
// that in SingleAccountRetry mode the service layer already retries in place
// (up to 3 times, 30s total wait), so the handler only needs a short fixed
// pause before re-entering it. That service-layer behavior is not visible
// here and should be confirmed against the service code.
//
// retryCount is only logged; the delay is always 2s. It returns true once the
// delay has elapsed and false if ctx is done first.
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
	// Fixed short delay, independent of retryCount.
	const delay = 2 * time.Second

	logger.L().With(
		zap.String("component", "handler.gateway.failover"),
		zap.Duration("delay", delay),
		zap.Int("retry_count", retryCount),
	).Info("gateway.single_account_backoff_waiting")

	// An explicit timer that is stopped on return releases its resources as
	// soon as ctx is canceled, instead of lingering until it fires as a
	// time.After timer does on Go toolchains before 1.23.
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody responseBody := failoverErr.ResponseBody
......
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff tests
// ---------------------------------------------------------------------------

// TestSleepAntigravitySingleAccountBackoff_ReturnsTrue checks that the backoff
// returns true after roughly its fixed delay when the context stays live.
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
	begin := time.Now()
	completed := sleepAntigravitySingleAccountBackoff(context.Background(), 1)
	waited := time.Since(begin)

	require.True(t, completed, "should return true when context is not canceled")
	// The delay is a fixed 2s.
	require.GreaterOrEqual(t, waited, 1500*time.Millisecond, "should wait approximately 2s")
	require.Less(t, waited, 5*time.Second, "should not wait too long")
}
// TestSleepAntigravitySingleAccountBackoff_ContextCanceled checks that the
// backoff returns false promptly when the context is already canceled.
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel() // cancel right away

	begin := time.Now()
	completed := sleepAntigravitySingleAccountBackoff(canceledCtx, 1)
	waited := time.Since(begin)

	require.False(t, completed, "should return false when context is canceled")
	require.Less(t, waited, 500*time.Millisecond, "should return immediately on cancel")
}
// TestSleepAntigravitySingleAccountBackoff_FixedDelay checks that a larger
// retryCount still results in the same fixed 2s delay.
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
	begin := time.Now()
	completed := sleepAntigravitySingleAccountBackoff(context.Background(), 5)
	waited := time.Since(begin)

	require.True(t, completed)
	// Even with retryCount=5 the delay remains the fixed 2s.
	require.GreaterOrEqual(t, waited, 1500*time.Millisecond)
	require.Less(t, waited, 5*time.Second)
}
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// Goal: strictly verify that when an antigravity account serves Claude through
// /v1/messages, and the account has credentials.intercept_warmup_requests=true
// and the request is a Warmup request, the backend intercepts the request
// before forwarding it upstream and returns a mock response (without depending
// on the upstream).

// fakeSchedulerCache is a test stub handed to service.NewSchedulerSnapshotService.
// GetSnapshot always returns the fixed account list with a true flag
// (presumably "cache hit"); every other method is a no-op returning zero
// values, except TryLockBucket which reports success.
type fakeSchedulerCache struct {
	accounts []*service.Account
}

func (s *fakeSchedulerCache) GetSnapshot(context.Context, service.SchedulerBucket) ([]*service.Account, bool, error) {
	return s.accounts, true, nil
}

func (s *fakeSchedulerCache) SetSnapshot(context.Context, service.SchedulerBucket, []service.Account) error {
	return nil
}

func (s *fakeSchedulerCache) GetAccount(context.Context, int64) (*service.Account, error) {
	return nil, nil
}

func (s *fakeSchedulerCache) SetAccount(context.Context, *service.Account) error {
	return nil
}

func (s *fakeSchedulerCache) DeleteAccount(context.Context, int64) error {
	return nil
}

func (s *fakeSchedulerCache) UpdateLastUsed(context.Context, map[int64]time.Time) error {
	return nil
}

func (s *fakeSchedulerCache) TryLockBucket(context.Context, service.SchedulerBucket, time.Duration) (bool, error) {
	return true, nil
}

func (s *fakeSchedulerCache) ListBuckets(context.Context) ([]service.SchedulerBucket, error) {
	return nil, nil
}

func (s *fakeSchedulerCache) GetOutboxWatermark(context.Context) (int64, error) {
	return 0, nil
}

func (s *fakeSchedulerCache) SetOutboxWatermark(context.Context, int64) error {
	return nil
}
// fakeGroupRepo is a test stub group repository. GetByID and GetByIDLite
// return the configured group regardless of the requested ID; all other
// methods are no-ops that return zero values and a nil error.
type fakeGroupRepo struct {
	group *service.Group
}

func (r *fakeGroupRepo) Create(context.Context, *service.Group) error {
	return nil
}

func (r *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) { return r.group, nil }

func (r *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
	return r.group, nil
}

func (r *fakeGroupRepo) Update(context.Context, *service.Group) error {
	return nil
}

func (r *fakeGroupRepo) Delete(context.Context, int64) error {
	return nil
}

func (r *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) {
	return nil, nil
}

func (r *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (r *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (r *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) {
	return nil, nil
}

func (r *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
	return nil, nil
}

func (r *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) {
	return false, nil
}

func (r *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) {
	return 0, nil
}

func (r *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	return 0, nil
}

func (r *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	return nil, nil
}

func (r *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error {
	return nil
}

func (r *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
	return nil
}
// fakeConcurrencyCache is a stateless test stub handed to
// service.NewConcurrencyService. Slot acquisition and wait-count increments
// always succeed, counters always read as zero, and the batch load lookups
// return empty (non-nil) maps.
type fakeConcurrencyCache struct{}

func (*fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
	return true, nil
}

func (*fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error {
	return nil
}

func (*fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
	return 0, nil
}

func (*fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
	return true, nil
}

func (*fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error {
	return nil
}

func (*fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
	return 0, nil
}

func (*fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
	return true, nil
}

func (*fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error {
	return nil
}

func (*fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) {
	return 0, nil
}

func (*fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
	return true, nil
}

func (*fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error {
	return nil
}

func (*fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	return map[int64]*service.AccountLoadInfo{}, nil
}

func (*fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
	return map[int64]*service.UserLoadInfo{}, nil
}

func (*fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error {
	return nil
}
// newTestGatewayHandler builds a GatewayHandler wired only with the
// dependencies these tests need: a gateway service backed by a fake scheduler
// snapshot (serving accounts) and a fake group repository (serving group), a
// billing cache service in simple run mode, and a concurrency helper backed by
// a fake concurrency cache.
//
// The returned function stops the billing cache service and should be called
// when the test finishes.
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
	t.Helper()

	snapshotSvc := service.NewSchedulerSnapshotService(&fakeSchedulerCache{accounts: accounts}, nil, nil, nil, nil)
	gatewaySvc := service.NewGatewayService(
		nil, // accountRepo (not used: scheduler snapshot hit)
		&fakeGroupRepo{group: group},
		nil, // usageLogRepo
		nil, // userRepo
		nil, // userSubRepo
		nil, // userGroupRateRepo
		nil, // cache (disable sticky)
		nil, // cfg
		snapshotSvc,
		nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
		nil, // billingService
		nil, // rateLimitService
		nil, // billingCacheService
		nil, // identityService
		nil, // httpUpstream
		nil, // deferredService
		nil, // claudeTokenProvider
		nil, // sessionLimitCache
		nil, // digestStore
	)

	// RunModeSimple: skips billing checks so no repo/cache dependencies are needed.
	billingSvc := service.NewBillingCacheService(nil, nil, nil, &config.Config{RunMode: config.RunModeSimple})
	helper := NewConcurrencyHelper(
		service.NewConcurrencyService(&fakeConcurrencyCache{}),
		SSEPingFormatClaude,
		0,
	)

	gh := &GatewayHandler{
		gatewayService:      gatewaySvc,
		billingCacheService: billingSvc,
		concurrencyHelper:   helper,
		// These tests are not sensitive to the switch limits; keep them small.
		maxAccountSwitches:       1,
		maxAccountSwitchesGemini: 1,
	}

	stop := func() {
		billingSvc.Stop()
	}
	return gh, stop
}
// TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1
// sends a Warmup request to Messages for an Anthropic-platform group whose only
// account is an Antigravity OAuth account with mixed_scheduling and
// intercept_warmup_requests enabled. It expects a 200 with the mock warmup
// body and the Antigravity account recorded as the selected account.
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		groupID   int64 = 2001
		accountID int64 = 1001
	)

	group := &service.Group{
		ID:       groupID,
		Hydrated: true,
		Platform: service.PlatformAnthropic, // /v1/messages (Claude-compatible) entry point
		Status:   service.StatusActive,
	}
	account := &service.Account{
		ID:       accountID,
		Name:     "ag-1",
		Platform: service.PlatformAntigravity,
		Type:     service.AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":              "tok_xxx",
			"intercept_warmup_requests": true,
		},
		Extra: map[string]any{
			"mixed_scheduling": true, // key point: lets the anthropic group pick this account via mixed scheduling
		},
		Concurrency:   1,
		Priority:      1,
		Status:        service.StatusActive,
		Schedulable:   true,
		AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
	}

	gh, stop := newTestGatewayHandler(t, group, []*service.Account{account})
	defer stop()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	payload := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
	httpReq := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	groupCtx := context.WithValue(httpReq.Context(), ctxkey.Group, group)
	ginCtx.Request = httpReq.WithContext(groupCtx)

	key := &service.APIKey{
		ID:      3001,
		UserID:  4001,
		GroupID: &groupID,
		Status:  service.StatusActive,
		User: &service.User{
			ID:          4001,
			Concurrency: 10,
			Balance:     100,
		},
		Group: group,
	}
	ginCtx.Set(string(middleware.ContextKeyAPIKey), key)
	ginCtx.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: key.UserID, Concurrency: 10})

	gh.Messages(ginCtx)

	require.Equal(t, 200, recorder.Code)

	// The antigravity account must really have been selected; this checks the
	// scheduling result through the handler rather than via a pure function.
	pickedID, found := ginCtx.Get(opsAccountIDKey)
	require.True(t, found)
	require.Equal(t, accountID, pickedID)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	require.Equal(t, "msg_mock_warmup", decoded["id"])
	require.Equal(t, "claude-sonnet-4-5", decoded["model"])

	blocks, isList := decoded["content"].([]any)
	require.True(t, isList)
	require.Len(t, blocks, 1)

	head, isMap := blocks[0].(map[string]any)
	require.True(t, isMap)
	require.Equal(t, "New Conversation", head["text"])
}
// TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform
// sends a Warmup request to Messages for an Antigravity-platform group with the
// ForcePlatform markers set on both the request context and the gin context.
// It expects a 200 with the mock warmup body and the Antigravity account
// recorded as the selected account.
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		groupID   int64 = 2002
		accountID int64 = 1002
	)

	group := &service.Group{
		ID:       groupID,
		Hydrated: true,
		Platform: service.PlatformAntigravity,
		Status:   service.StatusActive,
	}
	account := &service.Account{
		ID:       accountID,
		Name:     "ag-2",
		Platform: service.PlatformAntigravity,
		Type:     service.AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":              "tok_xxx",
			"intercept_warmup_requests": true,
		},
		Concurrency:   1,
		Priority:      1,
		Status:        service.StatusActive,
		Schedulable:   true,
		AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
	}

	gh, stop := newTestGatewayHandler(t, group, []*service.Account{account})
	defer stop()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	payload := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
	httpReq := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	// Mimic the effect of the ForcePlatform middleware in routes/gateway.go:
	// - written to request.Context (read by the service layer)
	// - written to gin.Context (fast path read by the handler)
	forcedCtx := context.WithValue(
		context.WithValue(httpReq.Context(), ctxkey.Group, group),
		ctxkey.ForcePlatform, service.PlatformAntigravity,
	)
	ginCtx.Request = httpReq.WithContext(forcedCtx)
	ginCtx.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)

	key := &service.APIKey{
		ID:      3002,
		UserID:  4002,
		GroupID: &groupID,
		Status:  service.StatusActive,
		User: &service.User{
			ID:          4002,
			Concurrency: 10,
			Balance:     100,
		},
		Group: group,
	}
	ginCtx.Set(string(middleware.ContextKeyAPIKey), key)
	ginCtx.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: key.UserID, Concurrency: 10})

	gh.Messages(ginCtx)

	require.Equal(t, 200, recorder.Code)

	pickedID, found := ginCtx.Get(opsAccountIDKey)
	require.True(t, found)
	require.Equal(t, accountID, pickedID)

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	require.Equal(t, "msg_mock_warmup", decoded["id"])
	require.Equal(t, "claude-sonnet-4-5", decoded["model"])
}
...@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
cleanedForUnknownBinding := false cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
...@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
// Antigravity 单账号退避重试:分组内没有其他可用账号时, action := fs.HandleSelectionExhausted(c.Request.Context())
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 switch action {
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 case FailoverContinue:
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { c.Request = c.Request.WithContext(ctx)
reqLog.Warn("gemini.single_account_retrying", continue
zap.Int("retry_count", switchCount), case FailoverCanceled:
zap.Int("max_retries", maxAccountSwitches), return
) default: // FailoverExhausted
failedAccountIDs = make(map[int64]struct{}) h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 return
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
} }
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform) setOpsSelectedAccount(c, account.ID, account.Platform)
...@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if switchCount > 0 { if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
} }
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
...@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
if needForceCacheBilling(hasBoundSession, failoverErr) { switch failoverAction {
forceCacheBilling = true case FailoverContinue:
} continue
if switchCount >= maxAccountSwitches { case FailoverExhausted:
lastFailoverErr = failoverErr h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
h.handleGeminiFailoverExhausted(c, lastFailoverErr) return
case FailoverCanceled:
return return
} }
lastFailoverErr = failoverErr
switchCount++
reqLog.Warn("gemini.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
} }
// ForwardNative already wrote the response // ForwardNative already wrote the response
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
...@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: forceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
...@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}) })
reqLog.Debug("gemini.request_completed", reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount), zap.Int("switch_count", fs.SwitchCount),
) )
return return
} }
......
...@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { ...@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestClient_ExchangeCode_成功(t *testing.T) { func TestClient_ExchangeCode_成功(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求方法 // 验证请求方法
...@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) { ...@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
} }
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "") old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("") client := NewClient("")
_, err := client.ExchangeCode(context.Background(), "code", "verifier") _, err := client.ExchangeCode(context.Background(), "code", "verifier")
...@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { ...@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
} }
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
...@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { ...@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestClient_RefreshToken_MockServer(t *testing.T) { func TestClient_RefreshToken_MockServer(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
...@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) { ...@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
} }
func TestClient_RefreshToken_无ClientSecret(t *testing.T) { func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "") old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("") client := NewClient("")
_, err := client.RefreshToken(context.Background(), "refresh-tok") _, err := client.RefreshToken(context.Background(), "refresh-tok")
...@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client { ...@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
...@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { ...@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
} }
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
...@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { ...@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
} }
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
...@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { ...@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
} }
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) // 模拟慢响应 time.Sleep(5 * time.Second) // 模拟慢响应
...@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { ...@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestClient_RefreshToken_Success_RealCall(t *testing.T) { func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
...@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) { ...@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
} }
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
...@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { ...@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
} }
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
...@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { ...@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
} }
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
......
...@@ -23,11 +23,9 @@ const ( ...@@ -23,11 +23,9 @@ const (
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证 // Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = ""
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri(用户需手动复制 code) // 固定的 redirect_uri(用户需手动复制 code)
...@@ -51,14 +49,21 @@ const ( ...@@ -51,14 +49,21 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
) )
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2 // defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
var defaultUserAgentVersion = "1.84.2" var defaultUserAgentVersion = "1.18.4"
// defaultClientSecret is the Antigravity OAuth client_secret used by default.
// It can be configured through the ANTIGRAVITY_OAUTH_CLIENT_SECRET environment variable.
// NOTE(review): a credential value is hard-coded in source here — confirm that
// committing it is intentional before relying on this default.
var defaultClientSecret = "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"
func init() { func init() {
// 从环境变量读取版本号,未设置则使用默认值 // 从环境变量读取版本号,未设置则使用默认值
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" { if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
defaultUserAgentVersion = version defaultUserAgentVersion = version
} }
// 从环境变量读取 client_secret,未设置则使用默认值
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
defaultClientSecret = secret
}
} }
// GetUserAgent 返回当前配置的 User-Agent // GetUserAgent 返回当前配置的 User-Agent
...@@ -67,14 +72,9 @@ func GetUserAgent() string { ...@@ -67,14 +72,9 @@ func GetUserAgent() string {
} }
func getClientSecret() (string, error) { func getClientSecret() (string, error) {
if v := strings.TrimSpace(ClientSecret); v != "" { if v := strings.TrimSpace(defaultClientSecret); v != "" {
return v, nil return v, nil
} }
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
if vv := strings.TrimSpace(v); vv != "" {
return vv, nil
}
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv) return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -17,8 +18,14 @@ import ( ...@@ -17,8 +18,14 @@ import (
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) { func TestGetClientSecret_环境变量设置(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
// 需要重新触发 init 逻辑:手动从环境变量读取
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
secret, err := getClientSecret() secret, err := getClientSecret()
if err != nil { if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err) t.Fatalf("获取 client_secret 失败: %v", err)
...@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) { ...@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
} }
func TestGetClientSecret_环境变量为空(t *testing.T) { func TestGetClientSecret_环境变量为空(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "") old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret() _, err := getClientSecret()
if err == nil { if err == nil {
t.Fatal("环境变量为空时应返回错误") t.Fatal("defaultClientSecret 为空时应返回错误")
} }
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
...@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) { ...@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
} }
func TestGetClientSecret_环境变量未设置(t *testing.T) { func TestGetClientSecret_环境变量未设置(t *testing.T) {
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在 old := defaultClientSecret
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值 defaultClientSecret = ""
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑 t.Cleanup(func() { defaultClientSecret = old })
// 明确设置再取消,确保环境变量不存在
t.Setenv(AntigravityOAuthClientSecretEnv, "")
_, err := getClientSecret() _, err := getClientSecret()
if err == nil { if err == nil {
t.Fatal("环境变量未设置时应返回错误") t.Fatal("defaultClientSecret 为空时应返回错误")
} }
} }
func TestGetClientSecret_环境变量含空格(t *testing.T) { func TestGetClientSecret_环境变量含空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " ") old := defaultClientSecret
defaultClientSecret = " "
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret() _, err := getClientSecret()
if err == nil { if err == nil {
t.Fatal("环境变量仅含空格时应返回错误") t.Fatal("defaultClientSecret 仅含空格时应返回错误")
} }
} }
func TestGetClientSecret_环境变量有前后空格(t *testing.T) { func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ") old := defaultClientSecret
defaultClientSecret = " valid-secret "
t.Cleanup(func() { defaultClientSecret = old })
secret, err := getClientSecret() secret, err := getClientSecret()
if err != nil { if err != nil {
...@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) { ...@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID) t.Errorf("ClientID 不匹配: got %s", ClientID)
} }
if ClientSecret != "" { secret, err := getClientSecret()
t.Error("ClientSecret 应为空字符串") if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
if secret != "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" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
} }
if RedirectURI != "http://localhost:8085/callback" { if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
} }
if GetUserAgent() != "antigravity/1.84.2 windows/amd64" { if GetUserAgent() != "antigravity/1.18.4 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
} }
if SessionTTL != 30*time.Minute { if SessionTTL != 30*time.Minute {
......
...@@ -206,6 +206,7 @@ type modelInfo struct { ...@@ -206,6 +206,7 @@ type modelInfo struct {
var modelInfoMap = map[string]modelInfo{ var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"}, "claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"}, "claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
} }
......
...@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("", h.Admin.Account.List) accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID) accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create) accounts.POST("", h.Admin.Account.Create)
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update) accounts.PUT("/:id", h.Admin.Account.Update)
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestAccount_IsInterceptWarmupEnabled asserts that IsInterceptWarmupEnabled
// returns true only when the "intercept_warmup_requests" credential holds the
// boolean true, and false when the credentials are nil or empty, the field is
// absent or nil, or the field holds a non-boolean value (string or int).
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
	// Credential key inspected by the method under test.
	const warmupKey = "intercept_warmup_requests"

	type scenario struct {
		desc  string
		creds map[string]any
		want  bool
	}

	scenarios := []scenario{
		{desc: "nil credentials", creds: nil, want: false},
		{desc: "empty map", creds: map[string]any{}, want: false},
		{desc: "field not present", creds: map[string]any{"access_token": "tok"}, want: false},
		{desc: "field is true", creds: map[string]any{warmupKey: true}, want: true},
		{desc: "field is false", creds: map[string]any{warmupKey: false}, want: false},
		// Non-boolean values are expected to be treated as disabled.
		{desc: "field is string true", creds: map[string]any{warmupKey: "true"}, want: false},
		{desc: "field is int 1", creds: map[string]any{warmupKey: 1}, want: false},
		{desc: "field is nil", creds: map[string]any{warmupKey: nil}, want: false},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.desc, func(t *testing.T) {
			account := &Account{Credentials: sc.creds}
			require.Equal(t, sc.want, account.IsInterceptWarmupEnabled())
		})
	}
}
...@@ -54,6 +54,7 @@ type AdminService interface { ...@@ -54,6 +54,7 @@ type AdminService interface {
SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
...@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc ...@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil return nil
} }
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
//
// It is the exported entry point of the AdminService interface and simply
// forwards all arguments to the unexported checkMixedChannelRisk, returning
// that call's error unchanged.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
}
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
if s.proxyLatencyCache == nil || len(proxies) == 0 { if s.proxyLatencyCache == nil || len(proxies) == 0 {
return return
......
...@@ -87,7 +87,6 @@ var ( ...@@ -87,7 +87,6 @@ var (
) )
const ( const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
) )
...@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
billingModel := mappedModel
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
ForceCacheBilling: switchErr.IsStickySession, ForceCacheBilling: switchErr.IsStickySession,
} }
} }
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
} }
resp := result.resp resp := result.resp
...@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: billingModel, // 使用映射模型用于计费和日志
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if mappedModel == "" { if mappedModel == "" {
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
} }
billingModel := mappedModel
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
ForceCacheBilling: switchErr.IsStickySession, ForceCacheBilling: switchErr.IsStickySession,
} }
} }
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
} }
resp := result.resp resp := result.resp
...@@ -2197,7 +2206,7 @@ handleSuccess: ...@@ -2197,7 +2206,7 @@ handleSuccess:
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: billingModel,
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError( ...@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
defaultDur := s.getDefaultRateLimitDuration() defaultDur := s.getDefaultRateLimitDuration()
// 尝试解析模型 key 并设置模型级限流 // 尝试解析模型 key 并设置模型级限流
modelKey := resolveAntigravityModelKey(requestedModel) //
// 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6),
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) == "" {
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
modelKey = resolveAntigravityModelKey(requestedModel)
}
if modelKey != "" { if modelKey != "" {
ra := s.resolveResetTime(resetAt, defaultDur) ra := s.resolveResetTime(resetAt, defaultDur)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
...@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
return nil, fmt.Errorf("missing model") return nil, fmt.Errorf("missing model")
} }
originalModel := claudeReq.Model originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL // 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages" upstreamURL := baseURL + "/v1/messages"
...@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
_, _ = c.Writer.Write(respBody) _, _ = c.Writer.Write(respBody)
return &ForwardResult{ return &ForwardResult{
Model: billingModel, Model: originalModel,
}, nil }, nil
} }
...@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. ...@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{ return &ForwardResult{
Model: billingModel, Model: originalModel,
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: duration, Duration: duration,
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, ...@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err return s.resp, s.err
} }
// antigravitySettingRepoStub is a stateless test double for the setting
// repository handed to NewSettingService in these tests. GetValue always
// reports ErrSettingNotFound; every other method panics so that an unexpected
// call is surfaced immediately.
type antigravitySettingRepoStub struct{}

// Get is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) Get(_ context.Context, _ string) (*Setting, error) {
	panic("unexpected Get call")
}

// GetValue reports every key as missing by returning ErrSettingNotFound.
func (*antigravitySettingRepoStub) GetValue(_ context.Context, _ string) (string, error) {
	return "", ErrSettingNotFound
}

// Set is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) Set(_ context.Context, _, _ string) error {
	panic("unexpected Set call")
}

// GetMultiple is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

// SetMultiple is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) SetMultiple(_ context.Context, _ map[string]string) error {
	panic("unexpected SetMultiple call")
}

// GetAll is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) GetAll(_ context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

// Delete is not expected to be called and panics if it is.
func (*antigravitySettingRepoStub) Delete(_ context.Context, _ string) error {
	panic("unexpected Delete call")
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder() writer := httptest.NewRecorder()
...@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { ...@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
} }
svc := &AntigravityGatewayService{ svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{}, settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
httpUpstream: &httpUpstreamStub{resp: resp}, tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
} }
account := &Account{ account := &Account{
...@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( ...@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
} }
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
// Verifies that the billing model returned by the Antigravity Claude forward
// path is the mapped model rather than the model named in the request.
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
	const (
		requestedModel = "claude-sonnet-4-5"
		mappedModel    = "gemini-3-pro-high"
	)

	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Claude-style streaming request naming the pre-mapping model.
	payload, err := json.Marshal(map[string]any{
		"model": requestedModel,
		"messages": []map[string]any{
			{"role": "user", "content": "hello"},
		},
		"max_tokens": 16,
		"stream":     true,
	})
	require.NoError(t, err)

	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(payload))

	// Canned upstream reply: a single SSE event with one candidate and usage metadata.
	sseEvent := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"X-Request-Id": []string{"req-bill-1"}},
		Body:       io.NopCloser(bytes.NewReader(sseEvent)),
	}

	// The account maps the requested model onto mappedModel.
	account := &Account{
		ID:          5,
		Name:        "acc-forward-billing",
		Platform:    PlatformAntigravity,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "token",
			"model_mapping": map[string]any{
				requestedModel: mappedModel,
			},
		},
	}

	cfg := &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}
	svc := &AntigravityGatewayService{
		settingService: NewSettingService(&antigravitySettingRepoStub{}, cfg),
		tokenProvider:  &AntigravityTokenProvider{},
		httpUpstream:   &httpUpstreamStub{resp: upstreamResp},
	}

	result, err := svc.Forward(context.Background(), ginCtx, account, payload, false)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, mappedModel, result.Model)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
// Verifies that the billing model returned by the Antigravity Gemini forward
// path is the mapped model rather than the model named in the request.
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
	const (
		requestedModel = "gemini-2.5-flash"
		mappedModel    = "gemini-3-pro-high"
	)

	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	// Minimal Gemini-style request body with a single user turn.
	payload, err := json.Marshal(map[string]any{
		"contents": []map[string]any{
			{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
		},
	})
	require.NoError(t, err)

	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(payload))

	// Canned upstream reply: a single SSE event with one candidate and usage metadata.
	sseEvent := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"X-Request-Id": []string{"req-bill-2"}},
		Body:       io.NopCloser(bytes.NewReader(sseEvent)),
	}

	// The account maps the requested model onto mappedModel.
	account := &Account{
		ID:          6,
		Name:        "acc-gemini-billing",
		Platform:    PlatformAntigravity,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "token",
			"model_mapping": map[string]any{
				requestedModel: mappedModel,
			},
		},
	}

	cfg := &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}
	svc := &AntigravityGatewayService{
		settingService: NewSettingService(&antigravitySettingRepoStub{}, cfg),
		tokenProvider:  &AntigravityTokenProvider{},
		httpUpstream:   &httpUpstreamStub{resp: upstreamResp},
	}

	result, err := svc.ForwardGemini(context.Background(), ginCtx, account, requestedModel, "generateContent", true, payload, false)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, mappedModel, result.Model)
}
// TestStreamUpstreamResponse_UsageAndFirstToken // TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment