"frontend/src/views/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "a381910e8657ae2ab40be5c62f92227edf7ca6ef"
Unverified Commit 2fe8932c authored by Call White's avatar Call White Committed by GitHub
Browse files

Merge pull request #3 from cyhhao/main

merge to main
parents 2f2e76f9 adb77af1
...@@ -26,11 +26,20 @@ func RegisterAuthRoutes( ...@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode) }), h.Auth.ValidatePromoCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
......
...@@ -22,6 +22,17 @@ func RegisterUserRoutes( ...@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile) user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword) user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile) user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
} }
// API Key管理 // API Key管理
......
...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time { ...@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil return nil
} }
// GetCredentialAsInt64 reads the credential stored under key and returns it as an int64.
// It is used to read internal numeric fields such as _token_version.
//
// Supported value types are int64, int, float64 (truncated toward zero),
// json.Number and base-10 strings (surrounding whitespace is ignored).
//
// It returns 0 when the receiver or its Credentials map is nil, when the key
// is absent or holds nil, when the value has an unsupported type, or when the
// value cannot be represented as an int64.
func (a *Account) GetCredentialAsInt64(key string) int64 {
	// 2^63 is exactly representable as a float64 and is the first value that
	// no longer fits into an int64.
	const int64Limit float64 = 1 << 63

	if a == nil || a.Credentials == nil {
		return 0
	}
	val, ok := a.Credentials[key]
	if !ok || val == nil {
		return 0
	}
	switch v := val.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		// Converting NaN, ±Inf or an out-of-range float64 to int64 produces an
		// implementation-dependent value, so only values inside [-2^63, 2^63)
		// are converted. NaN fails both comparisons and falls through to 0.
		if v >= -int64Limit && v < int64Limit {
			return int64(v)
		}
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return i
		}
	case string:
		if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
			return i
		}
	}
	// Unsupported type or unparsable/unrepresentable value.
	return 0
}
func (a *Account) IsTempUnschedulableEnabled() bool { func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil { if a.Credentials == nil {
return false return false
...@@ -592,6 +621,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool { ...@@ -592,6 +621,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken) return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
} }
// IsTLSFingerprintEnabled reports whether TLS fingerprint spoofing is switched on
// for this account.
// Only Anthropic OAuth/SetupToken accounts are eligible; any other account yields false.
// When enabled, the TLS handshake characteristics of the Claude Code (Node.js)
// client are imitated.
func (a *Account) IsTLSFingerprintEnabled() bool {
	// Only Anthropic OAuth/SetupToken accounts support this feature.
	if !a.IsAnthropicOAuthOrSetupToken() {
		return false
	}
	// Indexing a nil map is safe and yields the zero value, so a missing Extra
	// map, a missing key and a non-bool value all resolve to false here.
	enabled, isBool := a.Extra["enable_tls_fingerprint"].(bool)
	return isBool && enabled
}
// IsSessionIDMaskingEnabled reports whether session ID masking is switched on
// for this account.
// Only Anthropic OAuth/SetupToken accounts are eligible; any other account yields false.
// When enabled, the session ID inside metadata.user_id is kept fixed for a
// period of time (15 minutes) so that the upstream sees the requests as
// belonging to the same session.
func (a *Account) IsSessionIDMaskingEnabled() bool {
	if a.IsAnthropicOAuthOrSetupToken() && a.Extra != nil {
		if raw, found := a.Extra["session_id_masking_enabled"]; found {
			// A value that is not a bool counts as disabled.
			flag, _ := raw.(bool)
			return flag
		}
	}
	return false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用 // 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 { func (a *Account) GetWindowCostLimit() float64 {
...@@ -668,6 +735,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo ...@@ -668,6 +735,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return WindowCostNotSchedulable return WindowCostNotSchedulable
} }
// GetCurrentWindowStartTime returns the start time of the window that is
// currently in effect.
//
//  1. While the window has not expired (SessionWindowStart and SessionWindowEnd
//     are both set and SessionWindowEnd lies after the current time), the
//     recorded SessionWindowStart is returned.
//  2. Otherwise (window expired or never set), the predicted start of a new
//     window is returned: the current time truncated to the top of the hour in
//     the current location.
func (a *Account) GetCurrentWindowStartTime() time.Time {
	current := time.Now()

	windowActive := a.SessionWindowStart != nil &&
		a.SessionWindowEnd != nil &&
		current.Before(*a.SessionWindowEnd)
	if windowActive {
		return *a.SessionWindowStart
	}

	// Expired or unset: predict the new window start at the top of the current
	// hour. Per the original note this mirrors the prediction logic of
	// UpdateSessionWindow in ratelimit_service.go.
	year, month, day := current.Date()
	return time.Date(year, month, day, current.Hour(), 0, 0, 0, current.Location())
}
// parseExtraFloat64 从 extra 字段解析 float64 值 // parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 { func parseExtraFloat64(value any) float64 {
switch v := value.(type) { switch v := value.(type) {
......
...@@ -37,6 +37,7 @@ type AccountRepository interface { ...@@ -37,6 +37,7 @@ type AccountRepository interface {
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
......
...@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin ...@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
panic("unexpected SetError call") panic("unexpected SetError call")
} }
// ClearError always panics: code exercising accountRepoStub is not expected to
// call it, so reaching this method signals an unexpected repository call.
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
	panic("unexpected ClearError call")
}
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call") panic("unexpected SetSchedulable call")
} }
......
...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) { ...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{ "system": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.", "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{ "cache_control": map[string]string{
"type": "ephemeral", "type": "ephemeral",
}, },
...@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
......
...@@ -32,8 +32,8 @@ type UsageLogRepository interface { ...@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats // Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
...@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct { ...@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"` } `json:"seven_day_sonnet"`
} }
// ClaudeUsageFetchOptions bundles every option needed to fetch Claude usage data.
type ClaudeUsageFetchOptions struct {
	AccessToken          string       // OAuth access token
	ProxyURL             string       // proxy URL (optional)
	AccountID            int64        // account ID (used for TLS fingerprint selection)
	EnableTLSFingerprint bool         // whether TLS fingerprint spoofing is enabled
	Fingerprint          *Fingerprint // cached fingerprint information (User-Agent etc.)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API // ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface { type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
} }
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
...@@ -170,6 +181,7 @@ type AccountUsageService struct { ...@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher *AntigravityQuotaFetcher antigravityQuotaFetcher *AntigravityQuotaFetcher
cache *UsageCache cache *UsageCache
identityCache IdentityCache
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
...@@ -180,6 +192,7 @@ func NewAccountUsageService( ...@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService *GeminiQuotaService, geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher, antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache, cache *UsageCache,
identityCache IdentityCache,
) *AccountUsageService { ) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -188,6 +201,7 @@ func NewAccountUsageService( ...@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService: geminiQuotaService, geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: antigravityQuotaFetcher, antigravityQuotaFetcher: antigravityQuotaFetcher,
cache: cache, cache: cache,
identityCache: identityCache,
} }
} }
...@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
} }
dayStart := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
...@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute) minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute) minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil) minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
} }
...@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou ...@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询 // 如果没有缓存,从数据库查询
if windowStats == nil { if windowStats == nil {
var startTime time.Time // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
if account.SessionWindowStart != nil { startTime := account.GetCurrentWindowStartTime()
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
...@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI ...@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
} }
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo) // fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) { func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
...@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A ...@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) // 构建完整的选项
opts := &ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
AccountID: account.ID,
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
}
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
if s.identityCache != nil {
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
opts.Fingerprint = fp
}
}
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
} }
// parseTime 尝试多种格式解析时间 // parseTime 尝试多种格式解析时间
......
...@@ -42,6 +42,7 @@ type AdminService interface { ...@@ -42,6 +42,7 @@ type AdminService interface {
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
...@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac ...@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
return account, nil return account, nil
} }
// SetAccountError records errorMsg for the account identified by id by
// delegating to the account repository's SetError. The repository's error, if
// any, is returned unchanged.
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
	return s.accountRepo.SetError(ctx, id, errorMsg)
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
......
...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call") panic("unexpected RemoveGroupFromAllowedGroups call")
} }
// UpdateTotpSecret always panics: code exercising userRepoStub is not expected
// to call it.
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	panic("unexpected UpdateTotpSecret call")
}
// EnableTotp always panics: code exercising userRepoStub is not expected to
// call it.
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
	panic("unexpected EnableTotp call")
}
// DisableTotp always panics: code exercising userRepoStub is not expected to
// call it.
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
	panic("unexpected DisableTotp call")
}
type groupRepoStub struct { type groupRepoStub struct {
affectedUserIDs []int64 affectedUserIDs []int64
deleteErr error deleteErr error
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
mathrand "math/rand" mathrand "math/rand"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -28,6 +29,207 @@ const ( ...@@ -28,6 +29,207 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// antigravityScopeRateLimitEnv is the name of an environment variable.
// NOTE(review): presumably it toggles scope-level handling of Antigravity 429
// responses; the code that reads it is outside this view — confirm.
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
// antigravityRetryLoopParams carries the inputs of antigravityRetryLoop.
type antigravityRetryLoopParams struct {
	// ctx is checked for cancellation before every attempt and is passed to
	// request construction, backoff sleeps and handleError.
	ctx context.Context
	// prefix is prepended to every log line written by the loop.
	prefix string
	// account supplies the ID and Concurrency passed to httpUpstream.Do and the
	// Platform/ID/Name recorded in ops error events.
	account *Account
	// proxyURL is passed to httpUpstream.Do.
	proxyURL string
	// accessToken, action and body are passed to antigravity.NewAPIRequestWithURL
	// to build each upstream request.
	accessToken string
	action      string
	body        []byte
	// quotaScope is forwarded unchanged to handleError.
	quotaScope AntigravityQuotaScope
	// c is the gin context; when it is non-nil and body is non-empty, the body
	// is stored on it under OpsUpstreamRequestBodyKey. It is also passed to the
	// ops error recording helpers.
	c *gin.Context
	// httpUpstream executes the upstream request.
	httpUpstream HTTPUpstream
	// settingService is optional; its cfg.Gateway settings decide whether and
	// how much of an upstream error body is attached to ops events.
	settingService *SettingService
	// handleError is invoked once the retries for a 429 response are exhausted.
	handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
}
// antigravityRetryLoopResult is the outcome of antigravityRetryLoop.
type antigravityRetryLoopResult struct {
	// resp is the last upstream response obtained by the loop.
	resp *http.Response
}
// antigravityRetryLoop runs the upstream request with per-URL retries and
// fallback across base URLs.
//
// The candidate URLs come from antigravity.DefaultURLAvailability; if that list
// is empty, antigravity.BaseURLs is used. For each URL up to
// antigravityMaxRetries attempts are made:
//
//   - Before every attempt the context is checked; on cancellation (also
//     during a backoff sleep) ctx.Err() is returned.
//   - A transport error is recorded as an ops event. If
//     shouldAntigravityFallbackToNextURL accepts it and another URL remains, the
//     loop moves on to the next URL; otherwise it backs off and retries, and
//     after the last attempt returns a wrapped error.
//   - A 429 response has its body read (at most 2 MiB) and closed. A URL-level
//     limit switches to the next URL when one remains; otherwise the attempt is
//     retried with backoff. Once retries are exhausted p.handleError is called
//     and the response is returned with its body re-wrapped so callers can
//     still read it.
//   - Other statuses accepted by shouldRetryAntigravityError are retried the
//     same way, but without calling p.handleError when retries run out.
//   - Any other response ends the loop immediately with its body untouched.
//
// When the final response has a status below 400, the URL that produced it is
// marked successful. The returned result wraps that final response.
// NOTE(review): resp stays nil if no attempt is ever made (e.g.
// antigravityMaxRetries < 1) — confirm callers tolerate a nil resp.
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
	availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
	if len(availableURLs) == 0 {
		availableURLs = antigravity.BaseURLs
	}
	var resp *http.Response
	var usedBaseURL string
	// Upstream error bodies are attached to ops events only when the gateway
	// config enables it; maxBytes caps the attached detail (default 2048).
	logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
	maxBytes := 2048
	if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
		maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
	}
	getUpstreamDetail := func(body []byte) string {
		if !logBody {
			return ""
		}
		return truncateString(string(body), maxBytes)
	}
urlFallbackLoop:
	for urlIdx, baseURL := range availableURLs {
		usedBaseURL = baseURL
		for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
			// Non-blocking cancellation check before each attempt.
			select {
			case <-p.ctx.Done():
				log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
				return nil, p.ctx.Err()
			default:
			}
			upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
			if err != nil {
				return nil, err
			}
			// Capture upstream request body for ops retry of this attempt.
			if p.c != nil && len(p.body) > 0 {
				p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
			}
			resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
			if err != nil {
				safeErr := sanitizeUpstreamErrorMessage(err.Error())
				appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
					Platform:           p.account.Platform,
					AccountID:          p.account.ID,
					AccountName:        p.account.Name,
					UpstreamStatusCode: 0,
					Kind:               "request_error",
					Message:            safeErr,
				})
				// Prefer switching to the next URL over retrying the same one.
				if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
					log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
					continue urlFallbackLoop
				}
				if attempt < antigravityMaxRetries {
					log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
						return nil, p.ctx.Err()
					}
					continue
				}
				log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
				setOpsUpstreamError(p.c, 0, safeErr, "")
				return nil, fmt.Errorf("upstream request failed after retries: %w", err)
			}
			// 429 handling: distinguish URL-level rate limiting from account quota rate limiting.
			if resp.StatusCode == http.StatusTooManyRequests {
				// Read at most 2 MiB of the body; a read error is ignored and
				// whatever was read is used.
				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
				_ = resp.Body.Close()
				// "Resource has been exhausted" is a URL-level rate limit: switch URL.
				if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
					log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
					continue urlFallbackLoop
				}
				// Account/model quota rate limit: retry with backoff while attempts
				// remain (the original note says 3 retries, exponential backoff).
				if attempt < antigravityMaxRetries {
					upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
					upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
					appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
						Platform:           p.account.Platform,
						AccountID:          p.account.ID,
						AccountName:        p.account.Name,
						UpstreamStatusCode: resp.StatusCode,
						UpstreamRequestID:  resp.Header.Get("x-request-id"),
						Kind:               "retry",
						Message:            upstreamMsg,
						Detail:             getUpstreamDetail(respBody),
					})
					log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
						return nil, p.ctx.Err()
					}
					continue
				}
				// Retries exhausted: mark the account as rate limited.
				p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope)
				log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200))
				// The original body was consumed above, so hand back a response
				// whose body replays the bytes that were read.
				resp = &http.Response{
					StatusCode: resp.StatusCode,
					Header:     resp.Header.Clone(),
					Body:       io.NopCloser(bytes.NewReader(respBody)),
				}
				break urlFallbackLoop
			}
			// Other retryable errors.
			if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
				_ = resp.Body.Close()
				if attempt < antigravityMaxRetries {
					upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
					upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
					appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
						Platform:           p.account.Platform,
						AccountID:          p.account.ID,
						AccountName:        p.account.Name,
						UpstreamStatusCode: resp.StatusCode,
						UpstreamRequestID:  resp.Header.Get("x-request-id"),
						Kind:               "retry",
						Message:            upstreamMsg,
						Detail:             getUpstreamDetail(respBody),
					})
					log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
						return nil, p.ctx.Err()
					}
					continue
				}
				// Retries exhausted: return the response with a replayable body.
				resp = &http.Response{
					StatusCode: resp.StatusCode,
					Header:     resp.Header.Clone(),
					Body:       io.NopCloser(bytes.NewReader(respBody)),
				}
				break urlFallbackLoop
			}
			// Success or a non-retryable status: stop and return the response as is.
			break urlFallbackLoop
		}
	}
	// Remember the URL that produced a non-error response.
	if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
		antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
	}
	return &antigravityRetryLoopResult{resp: resp}, nil
}
// shouldRetryAntigravityError reports whether a response with the given HTTP
// status code should be retried. Only 429, 500, 502, 503, 504 and 529 are
// retryable.
func shouldRetryAntigravityError(statusCode int) bool {
	for _, retryable := range [...]int{429, 500, 502, 503, 504, 529} {
		if statusCode == retryable {
			return true
		}
	}
	return false
}
// isURLLevelRateLimit reports whether a 429 response body describes a URL-level
// rate limit, i.e. one where retrying against a different base URL may succeed.
//
// A body containing "Resource has been exhausted" is treated as a URL/node-level
// limit, unless it also contains "capacity on this model" (as in "exhausted
// your capacity on this model"), which indicates an account/model quota limit
// that switching URLs cannot fix. Matching is case-sensitive; a nil or empty
// body yields false.
//
// The substrings are searched directly in the byte slice so that the body is
// not copied into a string first.
func isURLLevelRateLimit(body []byte) bool {
	return bytes.Contains(body, []byte("Resource has been exhausted")) &&
		!bytes.Contains(body, []byte("capacity on this model"))
}
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) // isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool { func isAntigravityConnectionError(err error) bool {
if err == nil { if err == nil {
...@@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
if err != nil { if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err) lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 检查是否需要 URL 降级 // 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 解析流式响应,提取文本 // 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody) text := extractTextFromSSEResponse(respBody)
// 标记成功的 URL,下次优先使用
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{ return &TestConnectionResult{
Text: text, Text: text,
MappedModel: mappedModel, MappedModel: mappedModel,
...@@ -276,13 +478,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -276,13 +478,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
} }
// buildGeminiTestRequest 构建 Gemini 格式测试请求 // buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
payload := map[string]any{ payload := map[string]any{
"contents": []map[string]any{ "contents": []map[string]any{
{ {
"role": "user", "role": "user",
"parts": []map[string]any{ "parts": []map[string]any{
{"text": "hi"}, {"text": "."},
}, },
}, },
}, },
...@@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri ...@@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
{"text": antigravity.GetDefaultIdentityPatch()}, {"text": antigravity.GetDefaultIdentityPatch()},
}, },
}, },
"generationConfig": map[string]any{
"maxOutputTokens": 1,
},
} }
payloadBytes, _ := json.Marshal(payload) payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes) return s.wrapV1InternalRequest(projectID, model, payloadBytes)
} }
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式 // buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) { func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
claudeReq := &antigravity.ClaudeRequest{ claudeReq := &antigravity.ClaudeRequest{
Model: mappedModel, Model: mappedModel,
Messages: []antigravity.ClaudeMessage{ Messages: []antigravity.ClaudeMessage{
{ {
Role: "user", Role: "user",
Content: json.RawMessage(`"hi"`), Content: json.RawMessage(`"."`),
}, },
}, },
MaxTokens: 1024, MaxTokens: 1,
Stream: false, Stream: false,
} }
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel) return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
...@@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
sanitizeThinkingBlocks(&claudeReq)
// 获取转换选项 // 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429 // Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx) transformOpts := s.getClaudeTransformOptions(ctx)
...@@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// Safety net: ensure no cache_control leaked into Gemini request
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent" action := "streamGenerateContent"
// URL fallback 循环 // 执行带重试的请求
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() result, err := antigravityRetryLoop(antigravityRetryLoopParams{
if len(availableURLs) == 0 { ctx: ctx,
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 prefix: prefix,
} account: account,
proxyURL: proxyURL,
// 重试循环 accessToken: accessToken,
var resp *http.Response action: action,
urlFallbackLoop: body: geminiBody,
for urlIdx, baseURL := range availableURLs { quotaScope: quotaScope,
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { c: c,
// 检查 context 是否已取消(客户端断开连接) httpUpstream: s.httpUpstream,
select { settingService: s.settingService,
case <-ctx.Done(): handleError: s.handleUpstreamError,
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) })
return nil, ctx.Err() if err != nil {
default: return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
}
if err != nil {
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
// 检查是否应触发 URL 降级
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
// 检查是否应触发 URL 降级(仅 429)
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
// 最后一次尝试也失败
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
break urlFallbackLoop
}
} }
resp := result.resp
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
...@@ -739,11 +822,20 @@ urlFallbackLoop: ...@@ -739,11 +822,20 @@ urlFallbackLoop:
if txErr != nil { if txErr != nil {
continue continue
} }
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody) retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{
if buildErr != nil { ctx: ctx,
continue prefix: prefix,
} account: account,
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
if retryErr != nil { if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -757,6 +849,7 @@ urlFallbackLoop: ...@@ -757,6 +849,7 @@ urlFallbackLoop:
continue continue
} }
retryResp := retryResult.resp
if retryResp.StatusCode < 400 { if retryResp.StatusCode < 400 {
_ = resp.Body.Close() _ = resp.Body.Close()
resp = retryResp resp = retryResp
...@@ -766,6 +859,13 @@ urlFallbackLoop: ...@@ -766,6 +859,13 @@ urlFallbackLoop:
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := ""
if retryResp.Request != nil && retryResp.Request.URL != nil {
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
}
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
}
kind := "signature_retry" kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" { if strings.TrimSpace(stage.name) != "" {
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_") kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
...@@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string { ...@@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string {
return "" return ""
} }
// cleanCacheControlFromGeminiJSON is a last-resort safety net that strips any
// "cache_control" key from an already-transformed Gemini request body.
//
// The body is decoded as a JSON object, scrubbed recursively, and re-encoded.
// The input slice is returned unchanged when it cannot be decoded, when no
// "cache_control" key was found, or when re-encoding fails; only a body that
// actually changed is replaced by its re-marshalled form.
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
	var payload map[string]any
	err := json.Unmarshal(body, &payload)
	if err != nil {
		log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
		return body
	}

	// Nothing removed: hand back the original bytes untouched.
	if !removeCacheControlFromAny(payload) {
		return body
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		// Re-encoding failed; fall back to the unmodified body.
		return body
	}
	log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
	return encoded
}
// removeCacheControlFromAny walks a decoded JSON value (nested map[string]any
// and []any containers) and deletes every "cache_control" key it finds,
// modifying the maps in place.
//
// The value stored under a deleted "cache_control" key is not descended into.
// Values of any other type are ignored. It reports whether at least one key
// was removed anywhere in the tree.
func removeCacheControlFromAny(v any) bool {
	switch node := v.(type) {
	case map[string]any:
		removed := false
		if _, present := node["cache_control"]; present {
			delete(node, "cache_control")
			removed = true
		}
		// Every remaining child is visited, even after a removal was recorded.
		for _, child := range node {
			if removeCacheControlFromAny(child) {
				removed = true
			}
		}
		return removed
	case []any:
		removed := false
		for i := range node {
			// Call first so the recursion is never short-circuited.
			removed = removeCacheControlFromAny(node[i]) || removed
		}
		return removed
	default:
		return false
	}
}
// sanitizeThinkingBlocks scrubs thinking blocks in a Claude request in place.
//
// A block counts as a thinking block when its "type" is "thinking" or when it
// carries a non-nil "thinking" field. For such blocks:
//   - every nested "cache_control" key is removed (system prompt and messages);
//   - in every message except the last one, the block is rewritten as a plain
//     text block: "type" becomes "text", "text" receives the thinking content
//     (JSON-encoded when it is not a string), and the "thinking", "signature"
//     and "cache_control" fields are dropped.
//
// System and message content that does not decode as a JSON array of objects
// is left untouched. A decodable system array is always re-marshalled; message
// content is re-marshalled only when one of its blocks changed.
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
	if req == nil {
		return
	}

	log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))

	// isThinking recognises both typed and untyped thinking blocks.
	isThinking := func(block map[string]any) bool {
		kind, _ := block["type"].(string)
		return kind == "thinking" || block["thinking"] != nil
	}

	// System prompt: only cache_control is stripped, blocks are never flattened.
	if len(req.System) > 0 {
		var sysBlocks []map[string]any
		if json.Unmarshal(req.System, &sysBlocks) == nil {
			for idx, block := range sysBlocks {
				if isThinking(block) && removeCacheControlFromAny(block) {
					log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", idx)
				}
			}
			if encoded, err := json.Marshal(sysBlocks); err == nil {
				req.System = encoded
			}
		}
	}

	// Messages: strip cache_control everywhere, flatten thinking in history.
	finalIdx := len(req.Messages) - 1
	for m := range req.Messages {
		msg := &req.Messages[m]
		if len(msg.Content) == 0 {
			continue
		}

		// Content that is not an array of objects (e.g. a bare string) is skipped.
		var blocks []map[string]any
		if json.Unmarshal(msg.Content, &blocks) != nil {
			continue
		}

		isHistory := m < finalIdx
		modified := false
		for b, block := range blocks {
			if !isThinking(block) {
				continue
			}

			if removeCacheControlFromAny(block) {
				log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", m, b)
				modified = true
			}

			// The last message keeps its thinking blocks as they are.
			if !isHistory {
				continue
			}

			log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", m, b)

			// Non-string thinking payloads are preserved as their JSON encoding.
			text, isString := block["thinking"].(string)
			if !isString {
				if encoded, err := json.Marshal(block["thinking"]); err == nil {
					text = string(encoded)
				}
			}

			block["type"] = "text"
			block["text"] = text
			delete(block, "thinking")
			delete(block, "signature")
			delete(block, "cache_control")
			modified = true
		}

		if !modified {
			continue
		}
		if encoded, err := json.Marshal(blocks); err == nil {
			msg.Content = encoded
		}
	}
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request. // stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors. // This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text. // Note: redacted_thinking blocks are removed because they cannot be converted to text.
...@@ -1342,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1342,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err return nil, err
} }
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求 // 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody) wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil { if err != nil {
...@@ -1352,138 +1323,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1352,138 +1323,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent" upstreamAction := "streamGenerateContent"
// URL fallback 循环 // 执行带重试的请求
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() result, err := antigravityRetryLoop(antigravityRetryLoopParams{
if len(availableURLs) == 0 { ctx: ctx,
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 prefix: prefix,
} account: account,
proxyURL: proxyURL,
// 重试循环 accessToken: accessToken,
var resp *http.Response action: upstreamAction,
urlFallbackLoop: body: wrappedBody,
for urlIdx, baseURL := range availableURLs { quotaScope: quotaScope,
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { c: c,
// 检查 context 是否已取消(客户端断开连接) httpUpstream: s.httpUpstream,
select { settingService: s.settingService,
case <-ctx.Done(): handleError: s.handleUpstreamError,
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) })
return nil, ctx.Err() if err != nil {
default: return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
if err != nil {
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
// 检查是否应触发 URL 降级
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
// 检查是否应触发 URL 降级(仅 429)
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
break urlFallbackLoop
}
} }
resp := result.resp
defer func() { defer func() {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
_ = resp.Body.Close() _ = resp.Body.Close()
...@@ -1525,8 +1383,6 @@ urlFallbackLoop: ...@@ -1525,8 +1383,6 @@ urlFallbackLoop:
goto handleSuccess goto handleSuccess
} }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
if requestID != "" { if requestID != "" {
c.Header("x-request-id", requestID) c.Header("x-request-id", requestID)
...@@ -1537,6 +1393,7 @@ urlFallbackLoop: ...@@ -1537,6 +1393,7 @@ urlFallbackLoop:
if unwrapErr != nil || len(unwrappedForOps) == 0 { if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody unwrappedForOps = respBody
} }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
...@@ -1581,6 +1438,7 @@ urlFallbackLoop: ...@@ -1581,6 +1438,7 @@ urlFallbackLoop:
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps) c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
} }
...@@ -1637,15 +1495,6 @@ handleSuccess: ...@@ -1637,15 +1495,6 @@ handleSuccess:
}, nil }, nil
} }
// shouldRetryUpstreamError reports whether an upstream HTTP status code is one
// of the retryable statuses: 429, 500, 502, 503, 504 or 529.
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
	for _, retryable := range [...]int{429, 500, 502, 503, 504, 529} {
		if statusCode == retryable {
			return true
		}
	}
	return false
}
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool { func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode { switch statusCode {
case 401, 403, 429, 529: case 401, 403, 429, 529:
...@@ -1679,33 +1528,48 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { ...@@ -1679,33 +1528,48 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
} }
} }
// antigravityUseScopeRateLimit reports whether the environment variable named
// by antigravityScopeRateLimitEnv is set to a truthy value. After trimming
// surrounding whitespace and lower-casing, the accepted values are "1",
// "true", "yes" and "on"; anything else (including unset) yields false.
func antigravityUseScopeRateLimit() bool {
	switch strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv))) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间) // 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 { if statusCode == 429 {
useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body) resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil { if resetAt == nil {
// 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟 // 解析失败:使用配置的 fallback 时间,直接限流整个账户
defaultDur := 1 * time.Minute fallbackMinutes := 5
if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) { if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
defaultDur = 5 * time.Minute fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
} }
defaultDur := time.Duration(fallbackMinutes) * time.Minute
ra := time.Now().Add(defaultDur) ra := time.Now().Add(defaultDur)
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) if useScopeLimit {
if quotaScope == "" { log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
return if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
} log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { }
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) } else {
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
} }
return return
} }
resetTime := time.Unix(*resetAt, 0) resetTime := time.Unix(*resetAt, 0)
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) if useScopeLimit {
if quotaScope == "" { log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
return if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
} log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { }
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) } else {
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
} }
return return
} }
...@@ -1849,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -1849,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if u := extractGeminiUsage(parsed); u != nil { if u := extractGeminiUsage(parsed); u != nil {
usage = u usage = u
} }
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
} }
if firstTokenMs == nil { if firstTokenMs == nil {
...@@ -1884,7 +1761,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -1884,7 +1761,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
} }
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端 // handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应 // Gemini 流式响应是增量的,需要累积所有 chunk 的内容
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
...@@ -1897,6 +1774,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont ...@@ -1897,6 +1774,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int var firstTokenMs *int
var last map[string]any var last map[string]any
var lastWithParts map[string]any var lastWithParts map[string]any
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
var collectedTextParts []string // 收集所有文本片段
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -1996,9 +1875,33 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont ...@@ -1996,9 +1875,33 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage = u usage = u
} }
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
// 保留最后一个有 parts 的响应 // 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 { if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed lastWithParts = parsed
// 收集包含图片和文本的 parts
for _, part := range parts {
if inlineData, ok := part["inlineData"].(map[string]any); ok {
collectedImageParts = append(collectedImageParts, part)
_ = inlineData // 避免 unused 警告
}
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
} }
case <-intervalCh: case <-intervalCh:
...@@ -2020,6 +1923,16 @@ returnResponse: ...@@ -2020,6 +1923,16 @@ returnResponse:
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
} }
// 如果收集到了图片 parts,需要合并到最终响应中
if len(collectedImageParts) > 0 {
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
}
// 如果收集到了文本,需要合并到最终响应中
if len(collectedTextParts) > 0 {
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
}
respBody, err := json.Marshal(finalResponse) respBody, err := json.Marshal(finalResponse)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err) return nil, fmt.Errorf("failed to marshal response: %w", err)
...@@ -2029,6 +1942,167 @@ returnResponse: ...@@ -2029,6 +1942,167 @@ returnResponse:
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
// getOrCreateGeminiParts locates (or creates) candidates[0].content.parts in a
// Gemini response and returns an isolated copy of that path.
//
// Returns:
//   - result: a copy of response. The top-level map and the
//     candidates -> candidates[0] -> content chain are copied, so nothing the
//     caller does through result or setParts modifies response. Values outside
//     that chain are shared with response.
//   - existingParts: a copy of the current parts slice, or an empty slice when
//     there is none. The part values themselves are shared, not cloned.
//   - setParts: stores the given parts in the copied content and makes sure
//     the candidates slice is attached to result.
//
// When response has no usable candidates, a fresh candidate with
// content {"role": "model"} is prepared but only attached to result once
// setParts is called. When candidates exist, result reflects the normalised
// first candidate (a map that has a content map) immediately.
func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
	// Copy the top-level map.
	result = make(map[string]any, len(response))
	for k, v := range response {
		result[k] = v
	}

	// Copy the candidates slice, or start a new one that is attached lazily.
	var candidates []any
	if original, ok := response["candidates"].([]any); ok && len(original) > 0 {
		candidates = make([]any, len(original))
		copy(candidates, original)
		result["candidates"] = candidates
	} else {
		candidates = []any{map[string]any{}}
	}

	// Copy the first candidate; a non-map entry is replaced by an empty map.
	candidate := make(map[string]any)
	if original, ok := candidates[0].(map[string]any); ok {
		for k, v := range original {
			candidate[k] = v
		}
	}
	candidates[0] = candidate

	// Copy the content map, or create the default model content.
	var content map[string]any
	if original, ok := candidate["content"].(map[string]any); ok {
		content = make(map[string]any, len(original))
		for k, v := range original {
			content[k] = v
		}
	} else {
		content = map[string]any{"role": "model"}
	}
	candidate["content"] = content

	// Copy the parts slice so appends by the caller cannot reach the input's
	// backing array.
	existingParts = []any{}
	if parts, ok := content["parts"].([]any); ok {
		existingParts = make([]any, len(parts))
		copy(existingParts, parts)
	}

	setParts = func(newParts []any) {
		content["parts"] = newParts
		result["candidates"] = candidates
	}

	return result, existingParts, setParts
}
// mergeCollectedPartsToResponse replaces the parts of the first candidate in a
// Gemini response with the parts gathered while reading the stream.
//
// Order is preserved. Runs of consecutive plain text parts are concatenated
// into a single {"text": ...} part; parts flagged thought=true and every
// non-text part (functionCall, inlineData, ...) are passed through unchanged
// and terminate the current text run. When collectedParts is empty the
// response is returned as-is.
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
	if len(collectedParts) == 0 {
		return response
	}

	merged, _, apply := getOrCreateGeminiParts(response)

	var (
		out     []any
		pending strings.Builder
	)

	// emitPending turns the buffered plain text, if any, into one text part.
	emitPending := func() {
		if pending.Len() == 0 {
			return
		}
		out = append(out, map[string]any{
			"text": pending.String(),
		})
		pending.Reset()
	}

	for _, part := range collectedParts {
		text, isText := part["text"].(string)
		isThought, _ := part["thought"].(bool)

		if isText && !isThought {
			// Plain text: keep accumulating.
			_, _ = pending.WriteString(text)
			continue
		}

		// Thinking or non-text part: close the text run and keep the part verbatim.
		emitPending()
		out = append(out, part)
	}
	emitPending()

	apply(out)
	return merged
}
// mergeImagePartsToResponse appends the collected image parts to the parts of
// the first candidate in a Gemini response.
//
// Nothing is appended when the response already contains a part with an
// "inlineData" key. When imageParts is empty the response is returned as-is.
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
	if len(imageParts) == 0 {
		return response
	}

	merged, parts, apply := getOrCreateGeminiParts(response)

	// An image already present means the final chunk carried it; avoid duplicates.
	for i := range parts {
		entry, isMap := parts[i].(map[string]any)
		if !isMap {
			continue
		}
		if _, found := entry["inlineData"]; found {
			return merged
		}
	}

	for i := range imageParts {
		parts = append(parts, imageParts[i])
	}
	apply(parts)
	return merged
}
// mergeTextPartsToResponse writes the concatenation of the collected text
// fragments into the first candidate of a Gemini response.
//
// The first existing part that has a "text" key is replaced by a copy whose
// text is the joined string; all other parts are kept in place. If no part has
// a "text" key, a new text part is inserted at the front. When textParts is
// empty the response is returned as-is.
func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
	if len(textParts) == 0 {
		return response
	}

	joined := strings.Join(textParts, "")
	merged, current, apply := getOrCreateGeminiParts(response)

	rebuilt := make([]any, 0, len(current)+1)
	replaced := false
	for _, raw := range current {
		partMap, isMap := raw.(map[string]any)
		if !isMap {
			rebuilt = append(rebuilt, raw)
			continue
		}

		_, hasText := partMap["text"]
		if replaced || !hasText {
			rebuilt = append(rebuilt, partMap)
			continue
		}

		// Copy the part so the original map is not modified, then swap in the text.
		clone := make(map[string]any, len(partMap))
		for key, value := range partMap {
			clone[key] = value
		}
		clone["text"] = joined
		rebuilt = append(rebuilt, clone)
		replaced = true
	}

	if !replaced {
		rebuilt = append([]any{map[string]any{"text": joined}}, rebuilt...)
	}

	apply(rebuilt)
	return merged
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error { func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
"type": "error", "type": "error",
...@@ -2146,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2146,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int var firstTokenMs *int
var last map[string]any var last map[string]any
var lastWithParts map[string]any var lastWithParts map[string]any
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -2240,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont ...@@ -2240,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed last = parsed
// 保留最后一个有 parts 的响应 // 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 { if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed lastWithParts = parsed
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
collectedParts = append(collectedParts, parts...)
} }
case <-intervalCh: case <-intervalCh:
...@@ -2265,6 +2343,11 @@ returnResponse: ...@@ -2265,6 +2343,11 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
} }
// 将收集的所有 parts 合并到最终响应中
if len(collectedParts) > 0 {
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 序列化为 JSON(Gemini 格式) // 序列化为 JSON(Gemini 格式)
geminiBody, err := json.Marshal(finalResponse) geminiBody, err := json.Marshal(finalResponse)
if err != nil { if err != nil {
...@@ -2472,3 +2555,55 @@ func isImageGenerationModel(model string) bool { ...@@ -2472,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower == "gemini-2.5-flash-image-preview" || modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-") strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
} }
// cleanGeminiRequest sanitises the JSON schemas of every function declaration
// found under "tools" in a Gemini request body.
//
// Declarations are looked up under "functionDeclarations" and, if that key is
// not a list, under "function_declarations". Each map-valued "parameters"
// entry is passed through antigravity.DeepCleanUndefined and replaced by the
// result of antigravity.CleanJSONSchema.
//
// Returns the original bytes when nothing was cleaned, the re-marshalled
// payload otherwise, or the unmarshal/marshal error.
func cleanGeminiRequest(body []byte) ([]byte, error) {
	var req map[string]any
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}

	changed := false
	toolList, _ := req["tools"].([]any)
	for _, rawTool := range toolList {
		tool, isMap := rawTool.(map[string]any)
		if !isMap {
			continue
		}

		// camelCase takes precedence; snake_case is only consulted when the
		// camelCase key is absent or not a list.
		decls, isList := tool["functionDeclarations"].([]any)
		if !isList {
			decls, _ = tool["function_declarations"].([]any)
		}

		for _, rawDecl := range decls {
			decl, isMap := rawDecl.(map[string]any)
			if !isMap {
				continue
			}
			schema, hasSchema := decl["parameters"].(map[string]any)
			if !hasSchema {
				continue
			}
			antigravity.DeepCleanUndefined(schema)
			decl["parameters"] = antigravity.CleanJSONSchema(schema)
			changed = true
		}
	}

	if !changed {
		return body, nil
	}
	return json.Marshal(req)
}
...@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) { ...@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传 // Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true}, {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, {"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
...@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "gemini-2.5-flash", expected: "gemini-2.5-flash",
}, },
{ {
name: "Gemini透传 - gemini-1.5-pro", name: "Gemini透传 - gemini-2.5-pro",
requestedModel: "gemini-1.5-pro", requestedModel: "gemini-2.5-pro",
accountMapping: nil, accountMapping: nil,
expected: "gemini-1.5-pro", expected: "gemini-2.5-pro",
}, },
{ {
name: "Gemini透传 - gemini-future-model", name: "Gemini透传 - gemini-future-model",
......
...@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct { ...@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
// AntigravityTokenInfo token 信息 // AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct { type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
...@@ -141,18 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -141,18 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email result.Email = userInfo.Email
} }
// 获取 project_id(部分账户类型可能没有) // 获取 project_id(部分账户类型可能没有),失败时重试
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if err != nil { if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" { result.ProjectIDMissing = true
result.ProjectID = loadResp.CloudAICompanionProject } else {
} result.ProjectID = projectID
// 兜底:随机生成 project_id
if result.ProjectID == "" {
result.ProjectID = antigravity.GenerateMockProjectID()
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
} }
return result, nil return result, nil
...@@ -236,19 +232,66 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou ...@@ -236,19 +232,66 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, err return nil, err
} }
// 保留原有的 project_id 和 email // 保留原有的 email
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
existingEmail := strings.TrimSpace(account.GetCredential("email")) existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" { if existingEmail != "" {
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil return tokenInfo, nil
} }
// loadProjectIDWithRetry fetches the project_id via LoadCodeAssist, retrying
// on failure.
//
// One initial attempt plus up to maxRetries retries are made (a negative
// maxRetries is treated as 0). Retries are separated by an exponential backoff
// of 1s, 2s, 4s, ... capped at 8s. The wait is abandoned as soon as ctx is
// cancelled, in which case an error wrapping ctx.Err() is returned.
//
// An attempt counts as failed when the call returns an error, a nil response,
// or an empty project id. On exhaustion the returned error wraps the last
// failure.
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
	const maxBackoff = 8 * time.Second

	if maxRetries < 0 {
		maxRetries = 0
	}

	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Only shift for small exponents so the duration cannot overflow.
			backoff := maxBackoff
			if shift := attempt - 1; shift < 3 {
				backoff = time.Duration(1<<uint(shift)) * time.Second
			}

			// Wait with a timer instead of time.Sleep so cancellation is honoured.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return "", fmt.Errorf("获取 project_id 已取消: %w", ctx.Err())
			case <-timer.C:
			}
		}

		client := antigravity.NewClient(proxyURL)
		loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
		if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
			return loadResp.CloudAICompanionProject, nil
		}

		// Remember why this attempt failed.
		switch {
		case err != nil:
			lastErr = err
		case loadResp == nil:
			lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
		default:
			lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
		}
	}

	return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{ creds := map[string]any{
......
...@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou ...@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id") projectID := account.GetCredential("project_id")
// 如果没有 project_id,生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL) client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额 // 调用 API 获取配额
......
//go:build unit
package service
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// stubAntigravityUpstream is a test double for the HTTP upstream. It records
// every requested URL and answers based on the URL prefix (see Do).
type stubAntigravityUpstream struct {
	// firstBase is the URL prefix that Do answers with 429 Too Many Requests.
	firstBase string
	// secondBase is set by the test but not read by Do.
	secondBase string
	// calls holds every requested URL in call order.
	calls []string
}
// Do records the request URL and returns a canned response: 429 with a
// "Resource has been exhausted" error body when the URL starts with firstBase,
// otherwise 200 with the body "ok". It never returns an error.
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	target := req.URL.String()
	s.calls = append(s.calls, target)

	status, payload := http.StatusOK, "ok"
	if strings.HasPrefix(target, s.firstBase) {
		status = http.StatusTooManyRequests
		payload = `{"error":{"message":"Resource has been exhausted"}}`
	}

	return &http.Response{
		StatusCode: status,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(payload)),
	}, nil
}
// DoWithTLS ignores enableTLSFingerprint and delegates to Do.
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// scopeLimitCall captures the arguments of one SetAntigravityQuotaScopeLimit
// call on the stub repository.
type scopeLimitCall struct {
	accountID int64
	scope AntigravityQuotaScope
	resetAt time.Time
}
// rateLimitCall captures the arguments of one SetRateLimited call on the stub
// repository.
type rateLimitCall struct {
	accountID int64
	resetAt time.Time
}
// stubAntigravityAccountRepo embeds AccountRepository and overrides the two
// rate-limit setters so tests can inspect which one was invoked.
type stubAntigravityAccountRepo struct {
	AccountRepository
	// scopeCalls records SetAntigravityQuotaScopeLimit invocations in order.
	scopeCalls []scopeLimitCall
	// rateCalls records SetRateLimited invocations in order.
	rateCalls []rateLimitCall
}
// SetAntigravityQuotaScopeLimit records the call in scopeCalls and always
// returns nil.
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
	s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
	return nil
}
// SetRateLimited records the call in rateCalls and always returns nil.
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
	return nil
}
// TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess verifies that when the
// first base URL answers 429 the retry loop falls back to the second base URL,
// returns its 200 response without invoking handleError, and that the second
// URL is afterwards reported first among the available URLs.
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
	// Snapshot the package-level URL configuration and restore it on exit.
	savedURLs := append([]string(nil), antigravity.BaseURLs...)
	savedAvailability := antigravity.DefaultURLAvailability
	defer func() {
		antigravity.BaseURLs = savedURLs
		antigravity.DefaultURLAvailability = savedAvailability
	}()

	const (
		primaryBase  = "https://ag-1.test"
		fallbackBase = "https://ag-2.test"
	)
	antigravity.BaseURLs = []string{primaryBase, fallbackBase}
	antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)

	stub := &stubAntigravityUpstream{firstBase: primaryBase, secondBase: fallbackBase}
	acc := &Account{
		ID:          1,
		Name:        "acc-1",
		Platform:    PlatformAntigravity,
		Schedulable: true,
		Status:      StatusActive,
		Concurrency: 1,
	}

	errorHandled := false
	params := antigravityRetryLoopParams{
		prefix:       "[test]",
		ctx:          context.Background(),
		account:      acc,
		proxyURL:     "",
		accessToken:  "token",
		action:       "generateContent",
		body:         []byte(`{"input":"test"}`),
		quotaScope:   AntigravityQuotaScopeClaude,
		httpUpstream: stub,
		handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
			errorHandled = true
		},
	}

	result, err := antigravityRetryLoop(params)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.resp)
	defer func() { _ = result.resp.Body.Close() }()

	require.Equal(t, http.StatusOK, result.resp.StatusCode)
	require.False(t, errorHandled)

	// Exactly one request per base URL, in configuration order.
	require.Len(t, stub.calls, 2)
	require.True(t, strings.HasPrefix(stub.calls[0], primaryBase))
	require.True(t, strings.HasPrefix(stub.calls[1], fallbackBase))

	usable := antigravity.DefaultURLAvailability.GetAvailableURLs()
	require.NotEmpty(t, usable)
	require.Equal(t, fallbackBase, usable[0])
}
// TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled verifies that,
// with the scope rate-limit env var set to "true", a 429 is recorded as a
// single scope-level limit (and no account-level limit) whose reset time is
// about 3 seconds from now.
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
	t.Setenv(antigravityScopeRateLimitEnv, "true")

	recorder := &stubAntigravityAccountRepo{}
	service := &AntigravityGatewayService{accountRepo: recorder}
	acc := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}

	service.handleUpstreamError(
		context.Background(),
		"[test]",
		acc,
		http.StatusTooManyRequests,
		http.Header{},
		buildGeminiRateLimitBody("3s"),
		AntigravityQuotaScopeClaude,
	)

	require.Len(t, recorder.scopeCalls, 1)
	require.Empty(t, recorder.rateCalls)

	got := recorder.scopeCalls[0]
	require.Equal(t, acc.ID, got.accountID)
	require.Equal(t, AntigravityQuotaScopeClaude, got.scope)
	require.WithinDuration(t, time.Now().Add(3*time.Second), got.resetAt, 2*time.Second)
}
// TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled verifies
// that, with the scope rate-limit env var set to "false", a 429 is recorded as
// a single account-level limit (and no scope-level limit) whose reset time is
// about 2 seconds from now.
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
	t.Setenv(antigravityScopeRateLimitEnv, "false")

	recorder := &stubAntigravityAccountRepo{}
	service := &AntigravityGatewayService{accountRepo: recorder}
	acc := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}

	service.handleUpstreamError(
		context.Background(),
		"[test]",
		acc,
		http.StatusTooManyRequests,
		http.Header{},
		buildGeminiRateLimitBody("2s"),
		AntigravityQuotaScopeClaude,
	)

	require.Len(t, recorder.rateCalls, 1)
	require.Empty(t, recorder.scopeCalls)

	got := recorder.rateCalls[0]
	require.Equal(t, acc.ID, got.accountID)
	require.WithinDuration(t, time.Now().Add(2*time.Second), got.resetAt, 2*time.Second)
}
// TestAccountIsSchedulableForModel_AntigravityRateLimits verifies that an
// account-level rate limit blocks both Claude and Gemini models, while a
// limit stored only under the "claude" quota scope blocks the Claude model
// and leaves the Gemini model schedulable.
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
	resetAt := time.Now().Add(10 * time.Minute)

	acc := &Account{
		ID:          1,
		Name:        "acc",
		Platform:    PlatformAntigravity,
		Status:      StatusActive,
		Schedulable: true,
	}

	// Account-wide limit: nothing is schedulable.
	acc.RateLimitResetAt = &resetAt
	require.False(t, acc.IsSchedulableForModel("claude-sonnet-4-5"))
	require.False(t, acc.IsSchedulableForModel("gemini-3-flash"))

	// Scope-only limit on "claude": only the Claude model is blocked.
	acc.RateLimitResetAt = nil
	claudeScope := map[string]any{
		"rate_limit_reset_at": resetAt.Format(time.RFC3339),
	}
	acc.Extra = map[string]any{
		antigravityQuotaScopesKey: map[string]any{
			"claude": claudeScope,
		},
	}
	require.False(t, acc.IsSchedulableForModel("claude-sonnet-4-5"))
	require.True(t, acc.IsSchedulableForModel("gemini-3-flash"))
}
// buildGeminiRateLimitBody renders a Gemini-style 429 error payload whose
// details carry the given quotaResetDelay (emitted as a quoted string).
func buildGeminiRateLimitBody(delay string) []byte {
	const bodyTemplate = `{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`
	rendered := fmt.Sprintf(bodyTemplate, delay)
	return []byte(rendered)
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"log" "log"
"log/slog"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 3. 存入缓存 // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
ttl := 30 * time.Minute latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if expiresAt != nil { if isStale && latestAccount != nil {
until := time.Until(*expiresAt) // 版本过时,使用 DB 中的最新 token
switch { slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
case until > antigravityTokenCacheSkew: accessToken = latestAccount.GetCredential("access_token")
ttl = until - antigravityTokenCacheSkew if strings.TrimSpace(accessToken) == "" {
case until > 0: return "", errors.New("access_token not found after version check")
ttl = until
default:
ttl = time.Minute
} }
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
} }
return accessToken, nil return accessToken, nil
......
...@@ -3,6 +3,8 @@ package service ...@@ -3,6 +3,8 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strings"
"time" "time"
) )
...@@ -55,11 +57,33 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun ...@@ -55,11 +57,33 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
} }
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v newCredentials[k] = v
} }
} }
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing {
if tokenInfo.ProjectID != "" {
// 有旧的 project_id,本次获取失败,保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
}
return newCredentials, nil return newCredentials, nil
} }
...@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) { ...@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCacheL1 = cache s.authCacheL1 = cache
} }
// StartAuthCacheInvalidationSubscriber registers a handler with the cache's
// auth-cache invalidation subscription that deletes the received key from the
// L1 cache. Call it once the service is fully initialised.
//
// It does nothing when either the cache or the L1 cache is nil. A subscription
// error is only printed as a warning and is not returned.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
	if s.cache == nil || s.authCacheL1 == nil {
		return
	}

	// evict drops a single entry from the local L1 cache.
	evict := func(cacheKey string) {
		s.authCacheL1.Del(cacheKey)
	}

	err := s.cache.SubscribeAuthCacheInvalidation(ctx, evict)
	if err == nil {
		return
	}

	// Best effort: report the failure and carry on without the subscription.
	println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
}
func (s *APIKeyService) authCacheKey(key string) string { func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key)) sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
...@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { ...@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return return
} }
_ = s.cache.DeleteAuthCache(ctx, cacheKey) _ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
} }
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
......
...@@ -65,6 +65,10 @@ type APIKeyCache interface { ...@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error DeleteAuthCache(ctx context.Context, key string) error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
} }
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 // APIKeyAuthCacheInvalidator 提供认证缓存失效能力
......
...@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { ...@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return nil return nil
} }
// PublishAuthCacheInvalidation is a no-op on this stub and always returns nil.
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
	return nil
}
// SubscribeAuthCacheInvalidation is a no-op on this stub: the handler is never
// invoked and nil is always returned.
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
	return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{} cache := &authCacheStub{}
repo := &authRepoStub{ repo := &authRepoStub{
......
...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error ...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return nil return nil
} }
// PublishAuthCacheInvalidation is a no-op on this stub and always returns nil.
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
	return nil
}
// SubscribeAuthCacheInvalidation is a no-op on this stub: the handler is never
// invoked and nil is always returned.
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
	return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为: // 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - GetKeyAndOwnerID 返回所有者 ID 为 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment