Commit 292f25f9 authored by yangjianbo's avatar yangjianbo
Browse files
parents c92e3777 fbb57294
......@@ -575,6 +575,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
_, err := r.client.AccountGroup.Create().
SetAccountID(accountID).
......@@ -993,7 +1002,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
builder.SetSessionWindowEnd(*end)
}
_, err := builder.Save(ctx)
if err != nil {
return err
}
// 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
}
}
return nil
}
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
......
......@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -15,6 +16,7 @@ const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
authCacheInvalidateChannel = "auth:cache:invalidate"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
......@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
}
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
// Verify subscription is working
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
}
go func() {
defer func() {
if err := pubsub.Close(); err != nil {
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
}
}()
ch := pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case msg, ok := <-ch:
if !ok {
return
}
if msg != nil {
handler(msg.Payload)
}
}
}
}()
return nil
}
......@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
......@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
......@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
// 如果未启用 TLS 指纹,直接使用标准请求路径
if !enableTLSFingerprint {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TLS 指纹已启用,记录调试日志
targetHost := ""
if req != nil && req.URL != nil {
targetHost = req.URL.Host
}
proxyInfo := "direct"
if proxyURL != "" {
proxyInfo = proxyURL
}
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取 TLS 指纹 Profile
registry := tlsfingerprint.GlobalRegistry()
profile := registry.GetProfileByAccountID(accountID)
if profile == nil {
// 如果获取不到 profile,回退到普通请求
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
// 获取或创建带 TLS 指纹的客户端
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
if err != nil {
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
return nil, err
}
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
return nil, err
}
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
// 包装响应体,在关闭时自动减少计数并更新时间戳
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return resp, nil
}
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
}
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
return entry, nil
}
slog.Debug("tls_fingerprint_evicting_stale_client",
"account_id", accountID,
"cache_key", cacheKey,
"proxy_changed", entry.proxyKey != proxyKey,
"pool_changed", entry.poolKey != poolKey)
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 创建带 TLS 指纹的 Transport
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
settings := s.resolvePoolSettings(isolation, accountConcurrency)
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
......@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
return transport, nil
}
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
// - profile: TLS 指纹配置
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
// - error: 配置错误
//
// 代理类型处理:
// - nil/空: 直连,使用 TLSFingerprintDialer
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
ForceAttemptHTTP2: false,
}
// 根据代理类型选择合适的 TLS 指纹 Dialer
if proxyURL == nil {
// 直连:使用 TLSFingerprintDialer
slog.Debug("tls_fingerprint_transport_direct")
dialer := tlsfingerprint.NewDialer(profile, nil)
transport.DialTLSContext = dialer.DialTLSContext
} else {
scheme := strings.ToLower(proxyURL.Scheme)
switch scheme {
case "socks5", "socks5h":
// SOCKS5 代理:使用 SOCKS5ProxyDialer
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
transport.DialTLSContext = socks5Dialer.DialTLSContext
case "http", "https":
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
transport.DialTLSContext = httpDialer.DialTLSContext
default:
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}
}
}
return transport, nil
}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
......
......@@ -13,6 +13,8 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)
// fingerprintKey generates the Redis key for account fingerprint cache.
......@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
// maskedSessionKey generates the Redis key for masked session ID cache.
func maskedSessionKey(accountID int64) string {
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
}
type identityCache struct {
rdb *redis.Client
}
......@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
key := maskedSessionKey(accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return "", nil
}
return "", err
}
return val, nil
}
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
key := maskedSessionKey(accountID)
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
}
......@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
......@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
// 使用各账号自己的 idleTimeout,如果没有则用默认值
idleTimeout := c.defaultIdleTimeout
if idleTimeouts != nil {
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
idleTimeout = t
}
}
idleTimeoutSeconds := int(idleTimeout.Seconds())
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
......
......@@ -618,6 +618,14 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
......@@ -736,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return errors.New("not implemented")
}
......
......@@ -576,6 +576,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
}
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
func (a *Account) IsTLSFingerprintEnabled() bool {
// 仅支持 Anthropic OAuth/SetupToken 账号
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
// 使上游认为请求来自同一个会话
func (a *Account) IsSessionIDMaskingEnabled() bool {
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
......@@ -652,6 +690,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return WindowCostNotSchedulable
}
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
// 逻辑:
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
func (a *Account) GetCurrentWindowStartTime() time.Time {
now := time.Now()
// 窗口未过期,使用记录的窗口开始时间
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
return *a.SessionWindowStart
}
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 {
switch v := value.(type) {
......
......@@ -37,6 +37,7 @@ type AccountRepository interface {
UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
......
......@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
panic("unexpected SetError call")
}
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
panic("unexpected ClearError call")
}
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call")
}
......
......@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......
......@@ -369,12 +369,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询
if windowStats == nil {
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := account.GetCurrentWindowStartTime()
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
......
......@@ -42,6 +42,7 @@ type AdminService interface {
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
......@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
return account, nil
}
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return s.accountRepo.SetError(ctx, id, errorMsg)
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
......
......@@ -12,6 +12,7 @@ import (
mathrand "math/rand"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
......@@ -28,6 +29,207 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
// antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct {
ctx context.Context
prefix string
account *Account
proxyURL string
accessToken string
action string
body []byte
quotaScope AntigravityQuotaScope
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
}
// antigravityRetryLoopResult 重试循环的结果
type antigravityRetryLoopResult struct {
resp *http.Response
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs
}
var resp *http.Response
var usedBaseURL string
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
getUpstreamDetail := func(body []byte) string {
if !logBody {
return ""
}
return truncateString(string(body), maxBytes)
}
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
return nil, p.ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
return nil, err
}
// Capture upstream request body for ops retry of this attempt.
if p.c != nil && len(p.body) > 0 {
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
}
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
setOpsUpstreamError(p.c, 0, safeErr, "")
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
// 429 限流处理:区分 URL 级别限流和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// "Resource has been exhausted" 是 URL 级别限流,切换 URL
if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
// 账户/模型配额限流,重试 3 次(指数退避)
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
// 重试用尽,标记账户限流
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope)
log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200))
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
// 其他可重试错误
if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
break urlFallbackLoop
}
}
if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
}
return &antigravityRetryLoopResult{resp: resp}, nil
}
// shouldRetryAntigravityError 判断是否应该重试
func shouldRetryAntigravityError(statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
default:
return false
}
}
// isURLLevelRateLimit 判断是否为 URL 级别的限流(应切换 URL 重试)
// "Resource has been exhausted" 是 URL/节点级别限流,切换 URL 可能成功
// "exhausted your capacity on this model" 是账户/模型配额限流,切换 URL 无效
func isURLLevelRateLimit(body []byte) bool {
// 快速检查:包含 "Resource has been exhausted" 且不包含 "capacity on this model"
bodyStr := string(body)
return strings.Contains(bodyStr, "Resource has been exhausted") &&
!strings.Contains(bodyStr, "capacity on this model")
}
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool {
if err == nil {
......@@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
......@@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
......@@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
// 标记成功的 URL,下次优先使用
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
......@@ -276,13 +478,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "hi"},
{"text": "."},
},
},
},
......@@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
{"text": antigravity.GetDefaultIdentityPatch()},
},
},
"generationConfig": map[string]any{
"maxOutputTokens": 1,
},
}
payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
}
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
claudeReq := &antigravity.ClaudeRequest{
Model: mappedModel,
Messages: []antigravity.ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`"hi"`),
Content: json.RawMessage(`"."`),
},
},
MaxTokens: 1024,
MaxTokens: 1,
Stream: false,
}
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
......@@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL()
}
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
sanitizeThinkingBlocks(&claudeReq)
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
......@@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// Safety net: ensure no cache_control leaked into Gemini request
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环
var resp *http.Response
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
}
if err != nil {
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
// 执行带重试的请求
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
// 检查是否应触发 URL 降级
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
setOpsUpstreamError(c, 0, safeErr, "")
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
// 检查是否应触发 URL 降级(仅 429)
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
// 最后一次尝试也失败
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
break urlFallbackLoop
}
}
resp := result.resp
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
......@@ -739,11 +822,20 @@ urlFallbackLoop:
if txErr != nil {
continue
}
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr != nil {
continue
}
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -757,6 +849,7 @@ urlFallbackLoop:
continue
}
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
......@@ -766,6 +859,13 @@ urlFallbackLoop:
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := ""
if retryResp.Request != nil && retryResp.Request.URL != nil {
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
}
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
}
kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" {
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
......@@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string {
return ""
}
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
// This should not be needed if transformation is correct, but serves as a safety net
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
// Try a more robust approach: parse and clean
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
return body
}
cleaned := removeCacheControlFromAny(data)
if !cleaned {
return body
}
if result, err := json.Marshal(data); err == nil {
log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
return result
}
return body
}
// removeCacheControlFromAny recursively removes cache_control fields
func removeCacheControlFromAny(v any) bool {
cleaned := false
switch val := v.(type) {
case map[string]any:
for k, child := range val {
if k == "cache_control" {
delete(val, k)
cleaned = true
} else if removeCacheControlFromAny(child) {
cleaned = true
}
}
case []any:
for _, item := range val {
if removeCacheControlFromAny(item) {
cleaned = true
}
}
}
return cleaned
}
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
if req == nil {
return
}
log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
// Clean system blocks
if len(req.System) > 0 {
var systemBlocks []map[string]any
if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
for i := range systemBlocks {
if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
if removeCacheControlFromAny(systemBlocks[i]) {
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
}
}
}
// Marshal back
if cleaned, err := json.Marshal(systemBlocks); err == nil {
req.System = cleaned
}
}
}
// Clean message content blocks and flatten history
lastMsgIdx := len(req.Messages) - 1
for msgIdx := range req.Messages {
raw := req.Messages[msgIdx].Content
if len(raw) == 0 {
continue
}
// Try to parse as blocks array
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
cleaned := false
for blockIdx := range blocks {
blockType, _ := blocks[blockIdx]["type"].(string)
// Check for thinking blocks (typed or untyped)
if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
// 1. Clean cache_control
if removeCacheControlFromAny(blocks[blockIdx]) {
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
cleaned = true
}
// 2. Flatten to text if it's a history message (not the last one)
if msgIdx < lastMsgIdx {
log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
// Extract thinking content
var textContent string
if t, ok := blocks[blockIdx]["thinking"].(string); ok {
textContent = t
} else {
// Fallback for non-string content (marshal it)
if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
textContent = string(b)
}
}
// Convert to text block
blocks[blockIdx]["type"] = "text"
blocks[blockIdx]["text"] = textContent
delete(blocks[blockIdx], "thinking")
delete(blocks[blockIdx], "signature")
delete(blocks[blockIdx], "cache_control") // Ensure it's gone
cleaned = true
}
}
}
// Marshal back if modified
if cleaned {
if marshaled, err := json.Marshal(blocks); err == nil {
req.Messages[msgIdx].Content = marshaled
}
}
}
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
......@@ -1352,138 +1315,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环
var resp *http.Response
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
if err != nil {
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
// 执行带重试的请求
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
// 检查是否应触发 URL 降级
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
setOpsUpstreamError(c, 0, safeErr, "")
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
// 检查是否应触发 URL 降级(仅 429)
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
break urlFallbackLoop
}
}
resp := result.resp
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......@@ -1525,8 +1375,6 @@ urlFallbackLoop:
goto handleSuccess
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
......@@ -1537,6 +1385,7 @@ urlFallbackLoop:
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
......@@ -1581,6 +1430,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
......@@ -1637,15 +1487,6 @@ handleSuccess:
}, nil
}
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
default:
return false
}
}
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
......@@ -1679,34 +1520,49 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func antigravityUseScopeRateLimit() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟
defaultDur := 1 * time.Minute
if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
defaultDur = 5 * time.Minute
// 解析失败:使用配置的 fallback 时间,直接限流整个账户
fallbackMinutes := 5
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
}
defaultDur := time.Duration(fallbackMinutes) * time.Minute
ra := time.Now().Add(defaultDur)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
}
resetTime := time.Unix(*resetAt, 0)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
}
// 其他错误码继续使用 rateLimitService
......@@ -1884,7 +1740,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应
// Gemini 流式响应是增量的,需要累积所有 chunk 的内容
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -1897,6 +1753,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
var collectedTextParts []string // 收集所有文本片段
type scanEvent struct {
line string
......@@ -1999,6 +1857,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// 收集包含图片和文本的 parts
for _, part := range parts {
if inlineData, ok := part["inlineData"].(map[string]any); ok {
collectedImageParts = append(collectedImageParts, part)
_ = inlineData // 避免 unused 警告
}
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
case <-intervalCh:
......@@ -2020,6 +1888,16 @@ returnResponse:
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
// 如果收集到了图片 parts,需要合并到最终响应中
if len(collectedImageParts) > 0 {
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
}
// 如果收集到了文本,需要合并到最终响应中
if len(collectedTextParts) > 0 {
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
}
respBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
......@@ -2029,6 +1907,115 @@ returnResponse:
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
// 深拷贝 response
result = make(map[string]any)
for k, v := range response {
result[k] = v
}
// 获取或创建 candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// 获取第一个 candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// 获取或创建 content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// 获取现有 parts
existingParts, ok = content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// 返回更新回调
setParts = func(newParts []any) {
content["parts"] = newParts
result["candidates"] = candidates
}
return result, existingParts, setParts
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 {
return response
}
result, existingParts, setParts := getOrCreateGeminiParts(response)
// 检查现有 parts 中是否已经有图片
for _, p := range existingParts {
if pm, ok := p.(map[string]any); ok {
if _, hasInline := pm["inlineData"]; hasInline {
return result // 已有图片,不重复添加
}
}
}
// 添加收集到的图片 parts
for _, imgPart := range imageParts {
existingParts = append(existingParts, imgPart)
}
setParts(existingParts)
return result
}
// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中
func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
mergedText := strings.Join(textParts, "")
result, existingParts, setParts := getOrCreateGeminiParts(response)
// 查找并更新第一个 text part,或创建新的
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// 用累积的文本替换
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
setParts(newParts)
return result
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
......
......@@ -89,6 +89,7 @@ type AntigravityTokenInfo struct {
TokenType string `json:"token_type"`
Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
}
// ExchangeCode 用 authorization code 交换 token
......@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.ProjectID = loadResp.CloudAICompanionProject
}
// 兜底:随机生成 project_id
if result.ProjectID == "" {
result.ProjectID = antigravity.GenerateMockProjectID()
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
}
return result, nil
}
......@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, err
}
// 保留原有的 project_id 和 email
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
// 保留原有的 email
existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" {
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
}
return tokenInfo, nil
}
......
......@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
// 如果没有 project_id,生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额
......
//go:build unit
package service
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
type stubAntigravityUpstream struct {
firstBase string
secondBase string
calls []string
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String()
s.calls = append(s.calls, url)
if strings.HasPrefix(url, s.firstBase) {
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
}
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
}
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
base1 := "https://ag-1.test"
base2 := "https://ag-2.test"
antigravity.BaseURLs = []string{base1, base2}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.False(t, handleErrorCalled)
require.Len(t, upstream.calls, 2)
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available)
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "false")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("2s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.rateCalls, 1)
require.Empty(t, repo.scopeCalls)
call := repo.rateCalls[0]
require.Equal(t, account.ID, call.accountID)
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
}
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute)
account := &Account{
ID: 1,
Name: "acc",
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
}
account.RateLimitResetAt = &future
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
func buildGeminiRateLimitBody(delay string) []byte {
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
}
......@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
}
return newCredentials, nil
}
......@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCacheL1 = cache
}
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
// This should be called after the service is fully initialized.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
if s.cache == nil || s.authCacheL1 == nil {
return
}
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
}
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
......@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
......
......@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
......
......@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment