Unverified Commit 03c75787 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #325 from slovx2/main

fix(antigravity): 修复Antigravity 频繁429的问题,以及一系列优化,配置增强
parents de6797c5 c115c9e0
...@@ -257,6 +257,14 @@ type GatewayConfig struct { ...@@ -257,6 +257,14 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"` FailoverOn400 bool `mapstructure:"failover_on_400"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
// Scheduling: 账号调度相关配置 // Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
...@@ -298,6 +306,9 @@ type GatewaySchedulingConfig struct { ...@@ -298,6 +306,9 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算 // 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
...@@ -786,6 +797,9 @@ func setDefaults() { ...@@ -786,6 +797,9 @@ func setDefaults() {
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false) viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
...@@ -798,11 +812,12 @@ func setDefaults() { ...@@ -798,11 +812,12 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
......
...@@ -541,6 +541,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) { ...@@ -541,6 +541,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v newCredentials[k] = v
} }
} }
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
if tokenInfo.ProjectIDMissing {
// 先更新凭证
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
})
return
}
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
return
}
}
} else { } else {
// Use Anthropic/Claude OAuth service to refresh token // Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
......
...@@ -31,6 +31,8 @@ type GatewayHandler struct { ...@@ -31,6 +31,8 @@ type GatewayHandler struct {
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
} }
// NewGatewayHandler creates a new GatewayHandler // NewGatewayHandler creates a new GatewayHandler
...@@ -44,8 +46,16 @@ func NewGatewayHandler( ...@@ -44,8 +46,16 @@ func NewGatewayHandler(
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 10
maxAccountSwitchesGemini := 3
if cfg != nil { if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
}
} }
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
...@@ -54,6 +64,8 @@ func NewGatewayHandler( ...@@ -54,6 +64,8 @@ func NewGatewayHandler(
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
} }
} }
...@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
...@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
} }
const maxAccountSwitches = 10 maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
......
...@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionHash != "" { if sessionHash != "" {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
......
...@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct { ...@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
} }
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
...@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler( ...@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil { if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
} }
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
} }
} }
...@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key) // Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
const maxAccountSwitches = 3 maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
......
...@@ -16,15 +16,6 @@ import ( ...@@ -16,15 +16,6 @@ import (
"time" "time"
) )
// resolveHost 从 URL 解析 host
func resolveHost(urlStr string) string {
parsed, err := url.Parse(urlStr)
if err != nil {
return ""
}
return parsed.Host
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL,流式请求添加 ?alt=sse 参数 // 构建 URL,流式请求添加 ?alt=sse 参数
...@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri ...@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return nil, err return nil, err
} }
// 基础 Headers // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
// Accept Header 根据请求类型设置
if isStream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
// 显式设置 Host Header
if host := resolveHost(apiURL); host != "" {
req.Host = host
}
return req, nil return req, nil
} }
...@@ -195,12 +174,15 @@ func isConnectionError(err error) bool { ...@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
} }
// shouldFallbackToNextURL 判断是否应切换到下一个 URL // shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级 // 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool { func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) { if isConnectionError(err) {
return true return true
} }
return statusCode == http.StatusTooManyRequests return statusCode == http.StatusTooManyRequests ||
statusCode == http.StatusRequestTimeout ||
statusCode == http.StatusNotFound ||
statusCode >= 500
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
...@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
// 获取可用的 URL 列表 // 固定顺序:prod -> daily
availableURLs := DefaultURLAvailability.GetAvailableURLs() availableURLs := BaseURLs
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
...@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if err != nil { if err != nil {
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级 // 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var rawResp map[string]any var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL,下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &loadResp, rawResp, nil return &loadResp, rawResp, nil
} }
...@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
// 获取可用的 URL 列表 // 固定顺序:prod -> daily
availableURLs := DefaultURLAvailability.GetAvailableURLs() availableURLs := BaseURLs
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
...@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if err != nil { if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级 // 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue continue
} }
...@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var rawResp map[string]any var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL,下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &modelsResp, rawResp, nil return &modelsResp, rawResp, nil
} }
......
...@@ -143,9 +143,10 @@ type GeminiResponse struct { ...@@ -143,9 +143,10 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应 // GeminiCandidate Gemini 候选响应
type GeminiCandidate struct { type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"` Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"` FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
...@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct { ...@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
} }
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
type GeminiGroundingMetadata struct {
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
}
// GeminiGroundingChunk Gemini grounding chunk
type GeminiGroundingChunk struct {
Web *GeminiGroundingWeb `json:"web,omitempty"`
}
// GeminiGroundingWeb Gemini grounding web 信息
type GeminiGroundingWeb struct {
Title string `json:"title,omitempty"`
URI string `json:"uri,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤) // DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{ var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"}, {Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
......
...@@ -32,8 +32,8 @@ const ( ...@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs" "https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent(模拟官方客户端 // User-Agent(与 Antigravity-Manager 保持一致
UserAgent = "antigravity/1.104.0 darwin/arm64" UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间 // Session 过期时间
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
...@@ -42,22 +42,21 @@ const ( ...@@ -42,22 +42,21 @@ const (
URLAvailabilityTTL = 5 * time.Minute URLAvailabilityTTL = 5 * time.Minute
) )
// BaseURLs 定义 Antigravity API 端点,按优先级排序 // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
// fallback 顺序: sandbox → daily → prod
var BaseURLs = []string{ var BaseURLs = []string{
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox "https://cloudcode-pa.googleapis.com", // prod (优先)
"https://daily-cloudcode-pa.googleapis.com", // daily "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
"https://cloudcode-pa.googleapis.com", // prod
} }
// BaseURL 默认 URL(保持向后兼容) // BaseURL 默认 URL(保持向后兼容)
var BaseURL = BaseURLs[0] var BaseURL = BaseURLs[0]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级
type URLAvailability struct { type URLAvailability struct {
mu sync.RWMutex mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间 unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration ttl time.Duration
lastSuccess string // 最近成功请求的 URL,优先使用
} }
// DefaultURLAvailability 全局 URL 可用性管理器 // DefaultURLAvailability 全局 URL 可用性管理器
...@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) { ...@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u.unavailable[url] = time.Now().Add(u.ttl) u.unavailable[url] = time.Now().Add(u.ttl)
} }
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
func (u *URLAvailability) MarkSuccess(url string) {
u.mu.Lock()
defer u.mu.Unlock()
u.lastSuccess = url
// 成功后清除该 URL 的不可用标记
delete(u.unavailable, url)
}
// IsAvailable 检查 URL 是否可用 // IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool { func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock() u.mu.RLock()
...@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool { ...@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return time.Now().After(expiry) return time.Now().After(expiry)
} }
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) // GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string { func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock() u.mu.RLock()
defer u.mu.RUnlock() defer u.mu.RUnlock()
now := time.Now() now := time.Now()
result := make([]string, 0, len(BaseURLs)) result := make([]string, 0, len(BaseURLs))
// 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
}
}
// 添加其他可用的 URL(按默认顺序)
for _, url := range BaseURLs { for _, url := range BaseURLs {
// 跳过已添加的 lastSuccess
if url == u.lastSuccess {
continue
}
expiry, exists := u.unavailable[url] expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) { if !exists || now.After(expiry) {
result = append(result, url) result = append(result, url)
...@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string { ...@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
} }
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符(a-z0-9)
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}
...@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions { ...@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
} }
} }
// webSearchFallbackModel web_search 请求使用的降级模型
const webSearchFallbackModel = "gemini-2.5-flash"
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions()) return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
...@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map ...@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否有 web_search 工具
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
requestType := "agent"
targetModel := mappedModel
if hasWebSearchTool {
requestType = "web_search"
if targetModel != webSearchFallbackModel {
targetModel = webSearchFallbackModel
}
}
// 检测是否启用 thinking // 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
// 1. 构建 contents // 1. 构建 contents
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
...@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map ...@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
} }
// 2. 构建 systemInstruction // 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
// 3. 构建 generationConfig // 3. 构建 generationConfig
reqForConfig := claudeReq reqForConfig := claudeReq
...@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map ...@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy.Thinking = nil reqCopy.Thinking = nil
reqForConfig = &reqCopy reqForConfig = &reqCopy
} }
if targetModel != "" && targetModel != reqForConfig.Model {
reqCopy := *reqForConfig
reqCopy.Model = targetModel
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig) generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools // 4. 构建 tools
...@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map ...@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project: projectID, Project: projectID,
RequestID: "agent-" + uuid.New().String(), RequestID: "agent-" + uuid.New().String(),
UserAgent: "antigravity", // 固定值,与官方客户端一致 UserAgent: "antigravity", // 固定值,与官方客户端一致
RequestType: "agent", RequestType: requestType,
Model: mappedModel, Model: targetModel,
Request: innerRequest, Request: innerRequest,
} }
...@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string { ...@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
return antigravityIdentity return antigravityIdentity
} }
// buildSystemInstruction 构建 systemInstruction // mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent { const mcpXMLProtocol = `
==== MCP XML 工具调用协议 (Workaround) ====
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `。
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
===========================================`
// hasMCPTools 检测是否有 mcp__ 前缀的工具
func hasMCPTools(tools []ClaudeTool) bool {
for _, tool := range tools {
if strings.HasPrefix(tool.Name, "mcp__") {
return true
}
}
return false
}
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
func filterOpenCodePrompt(text string) string {
if !strings.Contains(text, "You are an interactive CLI tool") {
return text
}
// 提取 "Instructions from:" 及之后的部分
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
return text[idx:]
}
// 如果没有自定义指令,返回空
return ""
}
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart var parts []GeminiPart
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity // 先解析用户的 system prompt,检测是否已包含 Antigravity identity
...@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var sysStr string var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil { if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" { if strings.TrimSpace(sysStr) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") { if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
} }
} else { } else {
// 尝试解析为数组 // 尝试解析为数组
...@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil { if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks { for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" { if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") { if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true userHasAntigravityIdentity = true
} }
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
} }
} }
} }
...@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt // 添加用户的 system prompt
parts = append(parts, userSystemParts...) parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
}
// 如果用户没有提供 Antigravity 身份,添加结束标记
if !userHasAntigravityIdentity {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 { if len(parts) == 0 {
return nil return nil
} }
...@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { ...@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences: DefaultStopSequences, StopSequences: DefaultStopSequences,
} }
// 如果请求中指定了 MaxTokens,使用请求值
if req.MaxTokens > 0 {
config.MaxOutputTokens = req.MaxTokens
}
// Thinking 配置 // Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" { if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{ config.ThinkingConfig = &GeminiThinkingConfig{
...@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { ...@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return config return config
} }
func hasWebSearchTool(tools []ClaudeTool) bool {
for _, tool := range tools {
if isWebSearchTool(tool) {
return true
}
}
return false
}
func isWebSearchTool(tool ClaudeTool) bool {
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
return true
}
name := strings.TrimSpace(tool.Name)
switch name {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}
// buildTools 构建 tools // buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 { if len(tools) == 0 {
return nil return nil
} }
// 检查是否有 web_search 工具 hasWebSearch := hasWebSearchTool(tools)
hasWebSearch := false
for _, tool := range tools {
if tool.Name == "web_search" {
hasWebSearch = true
break
}
}
if hasWebSearch {
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for _, tool := range tools { for _, tool := range tools {
if isWebSearchTool(tool) {
continue
}
// 跳过无效工具名称 // 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" { if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
...@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
if len(funcDecls) == 0 { if len(funcDecls) == 0 {
return nil if !hasWebSearch {
return nil
}
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
} }
return []GeminiToolDeclaration{{ return []GeminiToolDeclaration{{
......
...@@ -3,6 +3,7 @@ package antigravity ...@@ -3,6 +3,7 @@ package antigravity
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
) )
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
...@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, ...@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p.processPart(&part) p.processPart(&part)
} }
if len(geminiResp.Candidates) > 0 {
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
p.processGrounding(grounding)
}
}
// 刷新剩余内容 // 刷新剩余内容
p.flushThinking() p.flushThinking()
p.flushText() p.flushText()
...@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { ...@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
} }
} }
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
groundingText := buildGroundingText(grounding)
if groundingText == "" {
return
}
p.flushThinking()
p.flushText()
p.textBuilder += groundingText
p.flushText()
}
// flushText 刷新 text builder // flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() { func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" { if p.textBuilder == "" {
...@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
} }
} }
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
if grounding == nil {
return ""
}
var builder strings.Builder
if len(grounding.WebSearchQueries) > 0 {
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
}
if len(grounding.GroundingChunks) > 0 {
var links []string
for i, chunk := range grounding.GroundingChunks {
if chunk.Web == nil {
continue
}
title := strings.TrimSpace(chunk.Web.Title)
if title == "" {
title = "Source"
}
uri := strings.TrimSpace(chunk.Web.URI)
if uri == "" {
uri = "#"
}
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
}
if len(links) > 0 {
_, _ = builder.WriteString("\n\nSources:\n")
_, _ = builder.WriteString(strings.Join(links, "\n"))
}
}
return builder.String()
}
// generateRandomID 生成随机 ID // generateRandomID 生成随机 ID
func generateRandomID() string { func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
......
...@@ -27,6 +27,8 @@ type StreamingProcessor struct { ...@@ -27,6 +27,8 @@ type StreamingProcessor struct {
pendingSignature string pendingSignature string
trailingSignature string trailingSignature string
originalModel string originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
// 累计 usage // 累计 usage
inputTokens int inputTokens int
...@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
} }
} }
if len(geminiResp.Candidates) > 0 {
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
}
// 检查是否结束 // 检查是否结束
if len(geminiResp.Candidates) > 0 { if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason finishReason := geminiResp.Candidates[0].FinishReason
...@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte { ...@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return result.Bytes() return result.Bytes()
} }
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
if grounding == nil {
return
}
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
}
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
}
}
// processThinking 处理 thinking // processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte { func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer var result bytes.Buffer
...@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p.trailingSignature = "" p.trailingSignature = ""
} }
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
groundingText := buildGroundingText(&GeminiGroundingMetadata{
WebSearchQueries: p.webSearchQueries,
GroundingChunks: p.groundingChunks,
})
if groundingText != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": groundingText,
}))
_, _ = result.Write(p.endBlock())
}
}
// 确定 stop_reason // 确定 stop_reason
stopReason := "end_turn" stopReason := "end_turn"
if p.usedTool { if p.usedTool {
......
...@@ -543,6 +543,15 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str ...@@ -543,6 +543,15 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return nil return nil
} }
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
_, err := r.client.AccountGroup.Create(). _, err := r.client.AccountGroup.Create().
SetAccountID(accountID). SetAccountID(accountID).
......
...@@ -744,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin ...@@ -744,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
return errors.New("not implemented") return errors.New("not implemented")
} }
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
......
...@@ -37,6 +37,7 @@ type AccountRepository interface { ...@@ -37,6 +37,7 @@ type AccountRepository interface {
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
......
...@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin ...@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
panic("unexpected SetError call") panic("unexpected SetError call")
} }
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
panic("unexpected ClearError call")
}
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call") panic("unexpected SetSchedulable call")
} }
......
...@@ -42,6 +42,7 @@ type AdminService interface { ...@@ -42,6 +42,7 @@ type AdminService interface {
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
...@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac ...@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
return account, nil return account, nil
} }
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return s.accountRepo.SetError(ctx, id, errorMsg)
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
......
...@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct { ...@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
// AntigravityTokenInfo token 信息 // AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct { type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"` ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
} }
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
...@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.ProjectID = loadResp.CloudAICompanionProject result.ProjectID = loadResp.CloudAICompanionProject
} }
// 兜底:随机生成 project_id
if result.ProjectID == "" {
result.ProjectID = antigravity.GenerateMockProjectID()
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
}
return result, nil return result, nil
} }
...@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou ...@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, err return nil, err
} }
// 保留原有的 project_id 和 email // 保留原有的 email
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
existingEmail := strings.TrimSpace(account.GetCredential("email")) existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" { if existingEmail != "" {
tokenInfo.Email = existingEmail tokenInfo.Email = existingEmail
} }
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
}
return tokenInfo, nil return tokenInfo, nil
} }
......
...@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou ...@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id") projectID := account.GetCredential("project_id")
// 如果没有 project_id,生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL) client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额 // 调用 API 获取配额
......
//go:build unit
package service
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
type stubAntigravityUpstream struct {
firstBase string
secondBase string
calls []string
}
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
url := req.URL.String()
s.calls = append(s.calls, url)
if strings.HasPrefix(url, s.firstBase) {
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
}
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
base1 := "https://ag-1.test"
base2 := "https://ag-2.test"
antigravity.BaseURLs = []string{base1, base2}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.False(t, handleErrorCalled)
require.Len(t, upstream.calls, 2)
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available)
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "true")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
t.Setenv(antigravityScopeRateLimitEnv, "false")
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("2s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
require.Len(t, repo.rateCalls, 1)
require.Empty(t, repo.scopeCalls)
call := repo.rateCalls[0]
require.Equal(t, account.ID, call.accountID)
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
}
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
now := time.Now()
future := now.Add(10 * time.Minute)
account := &Account{
ID: 1,
Name: "acc",
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
}
account.RateLimitResetAt = &future
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
func buildGeminiRateLimitBody(delay string) []byte {
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment