Commit e83f0ee3 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'main' into test-dev

parents bff3c66d 942c3e15
...@@ -297,6 +297,32 @@ go generate ./cmd/server ...@@ -297,6 +297,32 @@ go generate ./cmd/server
--- ---
## Antigravity Support
Sub2API supports [Antigravity](https://antigravity.so/) accounts. After authorization, dedicated endpoints are available for Claude and Gemini models.
### Dedicated Endpoints
| Endpoint | Model |
|----------|-------|
| `/antigravity/v1/messages` | Claude models |
| `/antigravity/v1beta/` | Gemini models |
### Claude Code Configuration
```bash
export ANTHROPIC_BASE_URL="http://localhost:8080/antigravity"
export ANTHROPIC_AUTH_TOKEN="sk-xxx"
```
### Hybrid Scheduling Mode
Antigravity accounts support optional **hybrid scheduling**. When enabled, the general endpoints `/v1/messages` and `/v1beta/` will also route requests to Antigravity accounts.
> **⚠️ Warning**: Anthropic Claude and Antigravity Claude **cannot be mixed within the same conversation context**. Use groups to isolate them properly.
---
## Project Structure ## Project Structure
``` ```
......
...@@ -307,6 +307,32 @@ go generate ./cmd/server ...@@ -307,6 +307,32 @@ go generate ./cmd/server
--- ---
## Antigravity 使用说明
Sub2API 支持 [Antigravity](https://antigravity.so/) 账户,授权后可通过专用端点访问 Claude 和 Gemini 模型。
### 专用端点
| 端点 | 模型 |
|------|------|
| `/antigravity/v1/messages` | Claude 模型 |
| `/antigravity/v1beta/` | Gemini 模型 |
### Claude Code 配置示例
```bash
export ANTHROPIC_BASE_URL="http://localhost:8080/antigravity"
export ANTHROPIC_AUTH_TOKEN="sk-xxx"
```
### 混合调度模式
Antigravity 账户支持可选的**混合调度**功能。开启后,通用端点 `/v1/messages``/v1beta/` 也会调度该账户。
> **⚠️ 注意**:Anthropic Claude 和 Antigravity Claude **不能在同一上下文中混合使用**,请通过分组功能做好隔离。
---
## 项目结构 ## 项目结构
``` ```
......
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage clean .PHONY: wire build build-embed test-unit test-integration test-e2e test-cover-integration clean-coverage
wire: wire:
@echo "生成 Wire 代码..." @echo "生成 Wire 代码..."
...@@ -21,6 +21,10 @@ test-unit: ...@@ -21,6 +21,10 @@ test-unit:
test-integration: test-integration:
@go test -tags integration ./... -count=1 -race -parallel=8 @go test -tags integration ./... -count=1 -race -parallel=8
test-e2e:
@echo "运行 E2E 测试(需要本地服务器运行)..."
@go test -tags e2e ./internal/integration/... -count=1 -v
test-cover-integration: test-cover-integration:
@echo "运行集成测试并生成覆盖率报告..." @echo "运行集成测试并生成覆盖率报告..."
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./... @go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
......
...@@ -29,26 +29,26 @@ type Application struct { ...@@ -29,26 +29,26 @@ type Application struct {
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build( wire.Build(
// 基础设施层 ProviderSets // Infrastructure layer ProviderSets
config.ProviderSet, config.ProviderSet,
infrastructure.ProviderSet, infrastructure.ProviderSet,
// 业务层 ProviderSets // Business layer ProviderSets
repository.ProviderSet, repository.ProviderSet,
service.ProviderSet, service.ProviderSet,
middleware.ProviderSet, middleware.ProviderSet,
handler.ProviderSet, handler.ProviderSet,
// 服务器层 ProviderSet // Server layer ProviderSet
server.ProviderSet, server.ProviderSet,
// BuildInfo provider // BuildInfo provider
provideServiceBuildInfo, provideServiceBuildInfo,
// 清理函数提供者 // Cleanup function provider
provideCleanup, provideCleanup,
// 应用程序结构体 // Application struct
wire.Struct(new(Application), "Server", "Cleanup"), wire.Struct(new(Application), "Server", "Cleanup"),
) )
return nil, nil return nil, nil
...@@ -70,6 +70,8 @@ func provideCleanup( ...@@ -70,6 +70,8 @@ func provideCleanup(
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
antigravityQuota *service.AntigravityQuotaRefresher,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
...@@ -104,6 +106,14 @@ func provideCleanup( ...@@ -104,6 +106,14 @@ func provideCleanup(
geminiOAuth.Stop() geminiOAuth.Stop()
return nil return nil
}}, }},
{"AntigravityOAuthService", func() error {
antigravityOAuth.Stop()
return nil
}},
{"AntigravityQuotaRefresher", func() error {
antigravityQuota.Stop()
return nil
}},
{"Redis", func() error { {"Redis", func() error {
return rdb.Close() return rdb.Close()
}}, }},
......
...@@ -102,6 +102,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -102,6 +102,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService) settingHandler := admin.NewSettingHandler(settingService, emailService)
...@@ -112,7 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -112,7 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
systemHandler := handler.ProvideSystemHandler(updateService) systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
pricingRemoteClient := repository.NewPricingRemoteClient() pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
...@@ -124,9 +126,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -124,9 +126,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
...@@ -136,8 +140,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -136,8 +140,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService) antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
...@@ -168,6 +173,8 @@ func provideCleanup( ...@@ -168,6 +173,8 @@ func provideCleanup(
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
antigravityQuota *service.AntigravityQuotaRefresher,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
...@@ -201,6 +208,14 @@ func provideCleanup( ...@@ -201,6 +208,14 @@ func provideCleanup(
geminiOAuth.Stop() geminiOAuth.Stop()
return nil return nil
}}, }},
{"AntigravityOAuthService", func() error {
antigravityOAuth.Stop()
return nil
}},
{"AntigravityQuotaRefresher", func() error {
antigravityQuota.Stop()
return nil
}},
{"Redis", func() error { {"Redis", func() error {
return rdb.Close() return rdb.Close()
}}, }},
......
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type AntigravityOAuthHandler struct {
antigravityOAuthService *service.AntigravityOAuthService
}
func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAuthService) *AntigravityOAuthHandler {
return &AntigravityOAuthHandler{antigravityOAuthService: antigravityOAuthService}
}
type AntigravityGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
}
// GenerateAuthURL generates Google OAuth authorization URL
// POST /api/v1/admin/antigravity/oauth/auth-url
func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req AntigravityGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "生成授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type AntigravityExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode 用 authorization code 交换 token
// POST /api/v1/admin/antigravity/oauth/exchange-code
func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
var req AntigravityExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.antigravityOAuthService.ExchangeCode(c.Request.Context(), &service.AntigravityExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Token 交换失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
...@@ -26,7 +26,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { ...@@ -26,7 +26,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"` IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
...@@ -39,7 +39,7 @@ type CreateGroupRequest struct { ...@@ -39,7 +39,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct { type UpdateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier *float64 `json:"rate_multiplier"` RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"` IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
type GatewayHandler struct { type GatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
...@@ -32,6 +33,7 @@ type GatewayHandler struct { ...@@ -32,6 +33,7 @@ type GatewayHandler struct {
func NewGatewayHandler( func NewGatewayHandler(
gatewayService *service.GatewayService, gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService, geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService, userService *service.UserService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
...@@ -39,6 +41,7 @@ func NewGatewayHandler( ...@@ -39,6 +41,7 @@ func NewGatewayHandler(
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
geminiCompatService: geminiCompatService, geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
...@@ -123,8 +126,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -123,8 +126,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 计算粘性会话hash // 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(body)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := "" platform := ""
if apiKey.Group != nil { if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forcePlatform
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
...@@ -163,8 +169,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -163,8 +169,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 转发请求 // 转发请求 - 根据账号平台分流
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
...@@ -240,8 +251,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -240,8 +251,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 转发请求 // 转发请求 - 根据账号平台分流
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
......
...@@ -25,13 +25,28 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { ...@@ -25,13 +25,28 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key") googleError(c, http.StatusUnauthorized, "Invalid API key")
return return
} }
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { // 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
// 强制 antigravity 模式:直接返回静态模型列表
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil { if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型列表
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
...@@ -56,7 +71,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { ...@@ -56,7 +71,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key") googleError(c, http.StatusUnauthorized, "Invalid API key")
return return
} }
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { // 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
...@@ -67,8 +84,21 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { ...@@ -67,8 +84,21 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
return return
} }
// 强制 antigravity 模式:直接返回静态模型信息
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil { if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型信息
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
...@@ -100,10 +130,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -100,10 +130,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return return
} }
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/")) modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil { if err != nil {
...@@ -182,8 +215,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -182,8 +215,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return return
} }
// 5) forward (writes response to client) // 5) forward (根据平台分流)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
} else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
}
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
......
//go:build unit
package handler
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string
description string
}{
{
name: "Gemini平台使用ForwardNative",
platform: service.PlatformGemini,
expectedService: "GeminiMessagesCompatService.ForwardNative",
description: "Gemini OAuth 账户直接调用 Google API",
},
{
name: "Antigravity平台使用ForwardGemini",
platform: service.PlatformAntigravity,
expectedService: "AntigravityGatewayService.ForwardGemini",
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
var routedService string
if tt.platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
require.Equal(t, tt.expectedService, routedService,
"平台 %s 应该路由到 %s: %s",
tt.platform, tt.expectedService, tt.description)
})
}
}
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态列表",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_fallback",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_fallback"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态模型信息",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_model_info",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_model_info"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}
...@@ -13,6 +13,7 @@ type AdminHandlers struct { ...@@ -13,6 +13,7 @@ type AdminHandlers struct {
OAuth *admin.OAuthHandler OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Setting *admin.SettingHandler Setting *admin.SettingHandler
......
...@@ -16,6 +16,7 @@ func ProvideAdminHandlers( ...@@ -16,6 +16,7 @@ func ProvideAdminHandlers(
oauthHandler *admin.OAuthHandler, oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler,
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
proxyHandler *admin.ProxyHandler, proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler, redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler, settingHandler *admin.SettingHandler,
...@@ -31,6 +32,7 @@ func ProvideAdminHandlers( ...@@ -31,6 +32,7 @@ func ProvideAdminHandlers(
OAuth: oauthHandler, OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler, OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler, GeminiOAuth: geminiOAuthHandler,
AntigravityOAuth: antigravityOAuthHandler,
Proxy: proxyHandler, Proxy: proxyHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Setting: settingHandler, Setting: settingHandler,
...@@ -98,6 +100,7 @@ var ProviderSet = wire.NewSet( ...@@ -98,6 +100,7 @@ var ProviderSet = wire.NewSet(
admin.NewOAuthHandler, admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler, admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler, admin.NewGeminiOAuthHandler,
admin.NewAntigravityOAuthHandler,
admin.NewProxyHandler, admin.NewProxyHandler,
admin.NewRedeemHandler, admin.NewRedeemHandler,
admin.NewSettingHandler, admin.NewSettingHandler,
......
//go:build e2e
package integration
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"time"
)
var (
baseURL = getEnv("BASE_URL", "http://localhost:8080")
// ENDPOINT_PREFIX: 端点前缀,支持混合模式和非混合模式测试
// - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户)
// - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户)
endpointPrefix = getEnv("ENDPOINT_PREFIX", "")
claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3"
geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f"
testInterval = 1 * time.Second // 测试间隔,防止限流
)
func getEnv(key, defaultVal string) string {
if v := os.Getenv(key); v != "" {
return v
}
return defaultVal
}
// Claude 模型列表
var claudeModels = []string{
// Opus 系列
"claude-opus-4-5-thinking", // 直接支持
"claude-opus-4", // 映射到 claude-opus-4-5-thinking
"claude-opus-4-5-20251101", // 映射到 claude-opus-4-5-thinking
// Sonnet 系列
"claude-sonnet-4-5", // 直接支持
"claude-sonnet-4-5-thinking", // 直接支持
"claude-sonnet-4-5-20250929", // 映射到 claude-sonnet-4-5-thinking
"claude-3-5-sonnet-20241022", // 映射到 claude-sonnet-4-5
// Haiku 系列(映射到 gemini-3-flash)
"claude-haiku-4",
"claude-haiku-4-5",
"claude-haiku-4-5-20251001",
"claude-3-haiku-20240307",
}
// Gemini 模型列表
var geminiModels = []string{
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
}
func TestMain(m *testing.M) {
mode := "混合模式"
if endpointPrefix != "" {
mode = "Antigravity 模式"
}
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode)
os.Exit(m.Run())
}
// TestClaudeModelsList 测试 GET /v1/models
func TestClaudeModelsList(t *testing.T) {
url := baseURL + endpointPrefix + "/v1/models"
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["object"] != "list" {
t.Errorf("期望 object=list, 得到 %v", result["object"])
}
data, ok := result["data"].([]any)
if !ok {
t.Fatal("响应缺少 data 数组")
}
t.Logf("✅ 返回 %d 个模型", len(data))
}
// TestGeminiModelsList 测试 GET /v1beta/models
func TestGeminiModelsList(t *testing.T) {
url := baseURL + endpointPrefix + "/v1beta/models"
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
models, ok := result["models"].([]any)
if !ok {
t.Fatal("响应缺少 models 数组")
}
t.Logf("✅ 返回 %d 个模型", len(models))
}
// TestClaudeMessages 测试 Claude /v1/messages 接口
func TestClaudeMessages(t *testing.T) {
for i, model := range claudeModels {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_非流式", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
func testClaudeMessage(t *testing.T, model string, stream bool) {
url := baseURL + endpointPrefix + "/v1/messages"
payload := map[string]any{
"model": model,
"max_tokens": 50,
"stream": stream,
"messages": []map[string]string{
{"role": "user", "content": "Say 'hello' in one word."},
},
}
body, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
if stream {
// 流式:读取 SSE 事件
scanner := bufio.NewScanner(resp.Body)
eventCount := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
eventCount++
if eventCount >= 3 {
break
}
}
}
if eventCount == 0 {
t.Fatal("未收到任何 SSE 事件")
}
t.Logf("✅ 收到 %d+ 个 SSE 事件", eventCount)
} else {
// 非流式:解析 JSON 响应
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["type"] != "message" {
t.Errorf("期望 type=message, 得到 %v", result["type"])
}
t.Logf("✅ 收到消息响应 id=%v", result["id"])
}
}
// TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口
func TestGeminiGenerateContent(t *testing.T) {
for i, model := range geminiModels {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_非流式", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}
func testGeminiGenerate(t *testing.T, model string, stream bool) {
action := "generateContent"
if stream {
action = "streamGenerateContent"
}
url := fmt.Sprintf("%s%s/v1beta/models/%s:%s", baseURL, endpointPrefix, model, action)
if stream {
url += "?alt=sse"
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]string{
{"text": "Say 'hello' in one word."},
},
},
},
"generationConfig": map[string]int{
"maxOutputTokens": 50,
},
}
body, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
if stream {
// 流式:读取 SSE 事件
scanner := bufio.NewScanner(resp.Body)
eventCount := 0
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
eventCount++
if eventCount >= 3 {
break
}
}
}
if eventCount == 0 {
t.Fatal("未收到任何 SSE 事件")
}
t.Logf("✅ 收到 %d+ 个 SSE 事件", eventCount)
} else {
// 非流式:解析 JSON 响应
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if _, ok := result["candidates"]; !ok {
t.Error("响应缺少 candidates 字段")
}
t.Log("✅ 收到 candidates 响应")
}
}
// TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求
// 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段
func TestClaudeMessagesWithComplexTools(t *testing.T) {
// 测试模型列表(只测试几个代表性模型)
models := []string{
"claude-opus-4-5-20251101", // Claude 模型
"claude-haiku-4-5-20251001", // 映射到 Gemini
}
for i, model := range models {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_复杂工具", func(t *testing.T) {
testClaudeMessageWithTools(t, model)
})
}
}
func testClaudeMessageWithTools(t *testing.T, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具)
// 这些字段需要被 cleanJSONSchema 清理
tools := []map[string]any{
{
"name": "read_file",
"description": "Read file contents",
"input_schema": map[string]any{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "File path",
"minLength": 1,
"maxLength": 4096,
"pattern": "^[^\\x00]+$",
},
"encoding": map[string]any{
"type": []string{"string", "null"},
"default": "utf-8",
"enum": []string{"utf-8", "ascii", "latin-1"},
},
},
"required": []string{"path"},
"additionalProperties": false,
},
},
{
"name": "write_file",
"description": "Write content to file",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"minLength": 1,
},
"content": map[string]any{
"type": "string",
"maxLength": 1048576,
},
},
"required": []string{"path", "content"},
"additionalProperties": false,
"strict": true,
},
},
{
"name": "list_files",
"description": "List files in directory",
"input_schema": map[string]any{
"$id": "https://example.com/list-files.schema.json",
"type": "object",
"properties": map[string]any{
"directory": map[string]any{
"type": "string",
},
"patterns": map[string]any{
"type": "array",
"items": map[string]any{
"type": "string",
"minLength": 1,
},
"minItems": 1,
"maxItems": 100,
"uniqueItems": true,
},
"recursive": map[string]any{
"type": "boolean",
"default": false,
},
},
"required": []string{"directory"},
"additionalProperties": false,
},
},
{
"name": "search_code",
"description": "Search code in files",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"minLength": 1,
"format": "regex",
},
"max_results": map[string]any{
"type": "integer",
"minimum": 1,
"maximum": 1000,
"exclusiveMinimum": 0,
"default": 100,
},
},
"required": []string{"query"},
"additionalProperties": false,
"examples": []map[string]any{
{"query": "function.*test", "max_results": 50},
},
},
},
// 测试 required 引用不存在的属性(应被自动过滤)
{
"name": "invalid_required_tool",
"description": "Tool with invalid required field",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
},
},
// "nonexistent_field" 不存在于 properties 中,应被过滤掉
"required": []string{"name", "nonexistent_field"},
},
},
// 测试没有 properties 的 schema(应自动添加空 properties)
{
"name": "no_properties_tool",
"description": "Tool without properties",
"input_schema": map[string]any{
"type": "object",
"required": []string{"should_be_removed"},
},
},
// 测试没有 type 的 schema(应自动添加 type: OBJECT)
{
"name": "no_type_tool",
"description": "Tool without type",
"input_schema": map[string]any{
"properties": map[string]any{
"value": map[string]any{
"type": "string",
},
},
},
},
}
payload := map[string]any{
"model": model,
"max_tokens": 100,
"stream": false,
"messages": []map[string]string{
{"role": "user", "content": "List files in the current directory"},
},
"tools": tools,
}
body, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 400 错误说明 schema 清理不完整
if resp.StatusCode == 400 {
t.Fatalf("Schema 清理失败,收到 400 错误: %s", string(respBody))
}
// 503 可能是账号限流,不算测试失败
if resp.StatusCode == 503 {
t.Skipf("账号暂时不可用 (503): %s", string(respBody))
}
// 429 是限流
if resp.StatusCode == 429 {
t.Skipf("请求被限流 (429): %s", string(respBody))
}
if resp.StatusCode != 200 {
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["type"] != "message" {
t.Errorf("期望 type=message, 得到 %v", result["type"])
}
t.Logf("✅ 复杂工具 schema 测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithThinkingAndTools 测试 thinking 模式下带工具调用的场景
// 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时,
// 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误
func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash
}
for i, model := range models {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_thinking模式工具调用", func(t *testing.T) {
testClaudeThinkingWithToolHistory(t, model)
})
}
}
func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话
// 注意:tool_use 块故意不包含 signature,测试系统是否能正确添加 dummy signature
payload := map[string]any{
"model": model,
"max_tokens": 200,
"stream": false,
// 开启 thinking 模式
"thinking": map[string]any{
"type": "enabled",
"budget_tokens": 1024,
},
"messages": []any{
map[string]any{
"role": "user",
"content": "List files in the current directory",
},
// assistant 消息包含 tool_use 但没有 signature
map[string]any{
"role": "assistant",
"content": []map[string]any{
{
"type": "text",
"text": "I'll list the files for you.",
},
{
"type": "tool_use",
"id": "toolu_01XGmNv",
"name": "Bash",
"input": map[string]any{"command": "ls -la"},
// 故意不包含 signature
},
},
},
// 工具结果
map[string]any{
"role": "user",
"content": []map[string]any{
{
"type": "tool_result",
"tool_use_id": "toolu_01XGmNv",
"content": "file1.txt\nfile2.txt\ndir1/",
},
},
},
},
"tools": []map[string]any{
{
"name": "Bash",
"description": "Execute bash commands",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{
"command": map[string]any{
"type": "string",
},
},
"required": []string{"command"},
},
},
},
}
body, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 400 错误说明 thought_signature 处理失败
if resp.StatusCode == 400 {
t.Fatalf("thought_signature 处理失败,收到 400 错误: %s", string(respBody))
}
// 503 可能是账号限流,不算测试失败
if resp.StatusCode == 503 {
t.Skipf("账号暂时不可用 (503): %s", string(respBody))
}
// 429 是限流
if resp.StatusCode == 429 {
t.Skipf("请求被限流 (429): %s", string(respBody))
}
if resp.StatusCode != 200 {
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["type"] != "message" {
t.Errorf("期望 type=message, 得到 %v", result["type"])
}
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature
}
for i, model := range models {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_无signature", func(t *testing.T) {
testClaudeWithNoSignature(t, model)
})
}
}
func testClaudeWithNoSignature(t *testing.T, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话包含 thinking block 但没有 signature
payload := map[string]any{
"model": model,
"max_tokens": 200,
"stream": false,
// 开启 thinking 模式
"thinking": map[string]any{
"type": "enabled",
"budget_tokens": 1024,
},
"messages": []any{
map[string]any{
"role": "user",
"content": "What is 2+2?",
},
// assistant 消息包含 thinking block 但没有 signature
map[string]any{
"role": "assistant",
"content": []map[string]any{
{
"type": "thinking",
"thinking": "Let me calculate 2+2...",
// 故意不包含 signature
},
{
"type": "text",
"text": "2+2 equals 4.",
},
},
},
map[string]any{
"role": "user",
"content": "What is 3+3?",
},
},
}
body, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 400 {
t.Fatalf("无 signature thinking 处理失败,收到 400 错误: %s", string(respBody))
}
if resp.StatusCode == 503 {
t.Skipf("账号暂时不可用 (503): %s", string(respBody))
}
if resp.StatusCode == 429 {
t.Skipf("请求被限流 (429): %s", string(respBody))
}
if resp.StatusCode != 200 {
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["type"] != "message" {
t.Errorf("期望 type=message, 得到 %v", result["type"])
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
package antigravity
import "encoding/json"
// Claude 请求/响应类型定义
// ClaudeRequest Claude Messages API 请求
type ClaudeRequest struct {
Model string `json:"model"`
Messages []ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
Tools []ClaudeTool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
}
// ClaudeMessage Claude 消息
type ClaudeMessage struct {
Role string `json:"role"` // user, assistant
Content json.RawMessage `json:"content"`
}
// ThinkingConfig Thinking 配置
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
}
// ClaudeMetadata 请求元数据
type ClaudeMetadata struct {
UserID string `json:"user_id,omitempty"`
}
// ClaudeTool Claude 工具定义
type ClaudeTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`
Text string `json:"text"`
}
// ContentBlock Claude 消息内容块(解析后)
type ContentBlock struct {
Type string `json:"type"`
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
IsError bool `json:"is_error,omitempty"`
// image
Source *ImageSource `json:"source,omitempty"`
}
// ImageSource Claude 图片来源
type ImageSource struct {
Type string `json:"type"` // "base64"
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
Data string `json:"data"`
}
// ClaudeResponse Claude Messages API 响应
type ClaudeResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ClaudeContentItem `json:"content"`
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
Usage ClaudeUsage `json:"usage"`
}
// ClaudeContentItem Claude 响应内容项
type ClaudeContentItem struct {
Type string `json:"type"` // text, thinking, tool_use
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
}
// ClaudeUsage Claude 用量统计
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
type ClaudeError struct {
Type string `json:"type"` // "error"
Error ErrorDetail `json:"error"`
}
// ErrorDetail 错误详情
type ErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`
}
package antigravity
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// TokenResponse Google OAuth token 响应
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// UserInfo Google 用户信息
type UserInfo struct {
Email string `json:"email"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
}
// LoadCodeAssistRequest loadCodeAssist 请求
type LoadCodeAssistRequest struct {
Metadata struct {
IDEType string `json:"ideType"`
} `json:"metadata"`
}
// TierInfo 账户类型信息
type TierInfo struct {
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
Name string `json:"name"` // 显示名称
Description string `json:"description"` // 描述
}
// IneligibleTier 不符合条件的层级信息
type IneligibleTier struct {
Tier *TierInfo `json:"tier,omitempty"`
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
ReasonCode string `json:"reasonCode,omitempty"`
ReasonMessage string `json:"reasonMessage,omitempty"`
}
// LoadCodeAssistResponse loadCodeAssist 响应
type LoadCodeAssistResponse struct {
CloudAICompanionProject string `json:"cloudaicompanionProject"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// GetTier 获取账户类型
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
if r.PaidTier != nil && r.PaidTier.ID != "" {
return r.PaidTier.ID
}
if r.CurrentTier != nil {
return r.CurrentTier.ID
}
return ""
}
// Client Antigravity API 客户端
type Client struct {
httpClient *http.Client
}
func NewClient(proxyURL string) *Client {
client := &http.Client{
Timeout: 30 * time.Second,
}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
}
}
}
return &Client{
httpClient: client,
}
}
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
params.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 交换请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取用户信息
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("用户信息请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var userInfo UserInfo
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
return nil, fmt.Errorf("用户信息解析失败: %w", err)
}
return &userInfo, nil
}
// LoadCodeAssist 获取 project_id
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, error) {
reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := BaseURL + "/v1internal:loadCodeAssist"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var loadResp LoadCodeAssistResponse
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
return nil, fmt.Errorf("响应解析失败: %w", err)
}
return &loadResp, nil
}
// ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"`
ResetTime string `json:"resetTime,omitempty"`
}
// ModelInfo 模型信息
type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
}
// FetchAvailableModelsRequest fetchAvailableModels 请求
type FetchAvailableModelsRequest struct {
Project string `json:"project"`
}
// FetchAvailableModelsResponse fetchAvailableModels 响应
type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"`
}
// FetchAvailableModels 获取可用模型和配额信息
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, error) {
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var modelsResp FetchAvailableModelsResponse
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
return nil, fmt.Errorf("响应解析失败: %w", err)
}
return &modelsResp, nil
}
package antigravity
// Gemini v1internal 请求/响应类型定义
// V1InternalRequest v1internal 请求包装
type V1InternalRequest struct {
Project string `json:"project"`
RequestID string `json:"requestId"`
UserAgent string `json:"userAgent"`
RequestType string `json:"requestType,omitempty"`
Model string `json:"model"`
Request GeminiRequest `json:"request"`
}
// GeminiRequest Gemini 请求内容
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
SessionID string `json:"sessionId,omitempty"`
}
// GeminiContent Gemini 内容
type GeminiContent struct {
Role string `json:"role"` // user, model
Parts []GeminiPart `json:"parts"`
}
// GeminiPart Gemini 内容部分
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
// GeminiInlineData Gemini 内联数据(图片等)
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// GeminiFunctionCall Gemini 函数调用
type GeminiFunctionCall struct {
Name string `json:"name"`
Args any `json:"args,omitempty"`
ID string `json:"id,omitempty"`
}
// GeminiFunctionResponse Gemini 函数响应
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response map[string]any `json:"response"`
ID string `json:"id,omitempty"`
}
// GeminiGenerationConfig Gemini 生成配置
type GeminiGenerationConfig struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// GeminiThinkingConfig Gemini thinking 配置
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts"`
ThinkingBudget int `json:"thinkingBudget,omitempty"`
}
// GeminiToolDeclaration Gemini 工具声明
type GeminiToolDeclaration struct {
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
}
// GeminiFunctionDecl Gemini 函数声明
type GeminiFunctionDecl struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
}
// GeminiGoogleSearch Gemini Google 搜索工具
type GeminiGoogleSearch struct {
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
}
// GeminiEnhancedContent 增强内容配置
type GeminiEnhancedContent struct {
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
}
// GeminiImageSearch 图片搜索配置
type GeminiImageSearch struct {
MaxResultCount int `json:"maxResultCount,omitempty"`
}
// GeminiToolConfig Gemini 工具配置
type GeminiToolConfig struct {
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
}
// GeminiFunctionCallingConfig 函数调用配置
type GeminiFunctionCallingConfig struct {
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
}
// GeminiSafetySetting Gemini 安全设置
type GeminiSafetySetting struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
// V1InternalResponse v1internal 响应包装
type V1InternalResponse struct {
Response GeminiResponse `json:"response"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiResponse Gemini 响应
type GeminiResponse struct {
Candidates []GeminiCandidate `json:"candidates,omitempty"`
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
}
// DefaultStopSequences 默认停止序列
var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}
package antigravity
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
const (
// Google OAuth 端点
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
// 固定的 redirect_uri(用户需手动复制 code)
RedirectURI = "http://localhost:8085/callback"
// OAuth scopes
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
"https://www.googleapis.com/auth/userinfo.email " +
"https://www.googleapis.com/auth/userinfo.profile " +
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
BaseURL = "https://cloudcode-pa.googleapis.com"
// User-Agent
UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
)
// OAuthSession 保存 OAuth 授权流程的临时状态
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore OAuth session 存储
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
func BuildAuthorizationURL(state, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("response_type", "code")
params.Set("scope", Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
package antigravity
import (
"encoding/json"
"fmt"
"strings"
"github.com/google/uuid"
)
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build contents: %w", err)
}
// 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
// 5. 构建内部请求
innerRequest := GeminiRequest{
Contents: contents,
SafetySettings: DefaultSafetySettings,
}
if systemInstruction != nil {
innerRequest.SystemInstruction = systemInstruction
}
if generationConfig != nil {
innerRequest.GenerationConfig = generationConfig
}
if len(tools) > 0 {
innerRequest.Tools = tools
innerRequest.ToolConfig = &GeminiToolConfig{
FunctionCallingConfig: &GeminiFunctionCallingConfig{
Mode: "VALIDATED",
},
}
}
// 如果提供了 metadata.user_id,复用为 sessionId
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
innerRequest.SessionID = claudeReq.Metadata.UserID
}
// 6. 包装为 v1internal 请求
v1Req := V1InternalRequest{
Project: projectID,
RequestID: "agent-" + uuid.New().String(),
UserAgent: "sub2api",
RequestType: "agent",
Model: mappedModel,
Request: innerRequest,
}
return json.Marshal(v1Req)
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string) *GeminiContent {
var parts []GeminiPart
// 注入身份防护指令
identityPatch := fmt.Sprintf(
"--- [IDENTITY_PATCH] ---\n"+
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
"You are currently providing services as the native %s model via a standard API proxy.\n"+
"Always use the 'claude' command for terminal tasks if relevant.\n"+
"--- [SYSTEM_PROMPT_BEGIN] ---\n",
modelName,
)
parts = append(parts, GeminiPart{Text: identityPatch})
// 解析 system prompt
if len(system) > 0 {
// 尝试解析为字符串
var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" {
parts = append(parts, GeminiPart{Text: sysStr})
}
} else {
// 尝试解析为数组
var sysBlocks []SystemBlock
if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
parts = append(parts, GeminiPart{Text: block.Text})
}
}
}
}
}
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
return &GeminiContent{
Role: "user",
Parts: parts,
}
}
// buildContents 构建 contents
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, error) {
var contents []GeminiContent
for i, msg := range messages {
role := msg.Role
if role == "assistant" {
role = "model"
}
parts, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build parts for message %d: %w", i, err)
}
// 只有 Gemini 模型支持 dummy thinking block workaround
// 只对最后一条 assistant 消息添加(Pre-fill 场景)
// 历史 assistant 消息不能添加没有 signature 的 dummy thinking block
if allowDummyThought && role == "model" && isThinkingEnabled && i == len(messages)-1 {
hasThoughtPart := false
for _, p := range parts {
if p.Thought {
hasThoughtPart = true
break
}
}
if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
}}, parts...)
}
}
if len(parts) == 0 {
continue
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
return contents, nil
}
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
var parts []GeminiPart
// 尝试解析为字符串
var textContent string
if err := json.Unmarshal(content, &textContent); err == nil {
if textContent != "(no content)" && strings.TrimSpace(textContent) != "" {
parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)})
}
return parts, nil
}
// 解析为内容块数组
var blocks []ContentBlock
if err := json.Unmarshal(content, &blocks); err != nil {
return nil, fmt.Errorf("parse content blocks: %w", err)
}
for _, block := range blocks {
switch block.Type {
case "text":
if block.Text != "(no content)" && strings.TrimSpace(block.Text) != "" {
parts = append(parts, GeminiPart{Text: block.Text})
}
case "thinking":
part := GeminiPart{
Text: block.Thinking,
Thought: true,
}
// 保留原有 signature(Claude 模型需要有效的 signature)
if block.Signature != "" {
part.ThoughtSignature = block.Signature
}
parts = append(parts, part)
case "image":
if block.Source != nil && block.Source.Type == "base64" {
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: block.Source.MediaType,
Data: block.Source.Data,
},
})
}
case "tool_use":
// 存储 id -> name 映射
if block.ID != "" && block.Name != "" {
toolIDToName[block.ID] = block.Name
}
part := GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: block.Name,
Args: block.Input,
ID: block.ID,
},
}
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
case "tool_result":
// 获取函数名
funcName := block.Name
if funcName == "" {
if name, ok := toolIDToName[block.ToolUseID]; ok {
funcName = name
} else {
funcName = block.ToolUseID
}
}
// 解析 content
resultContent := parseToolResultContent(block.Content, block.IsError)
parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: funcName,
Response: map[string]any{
"result": resultContent,
},
ID: block.ToolUseID,
},
})
}
}
return parts, nil
}
// parseToolResultContent 解析 tool_result 的 content
func parseToolResultContent(content json.RawMessage, isError bool) string {
if len(content) == 0 {
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(content, &str); err == nil {
if strings.TrimSpace(str) == "" {
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
return str
}
// 尝试解析为数组
var arr []map[string]any
if err := json.Unmarshal(content, &arr); err == nil {
var texts []string
for _, item := range arr {
if text, ok := item["text"].(string); ok {
texts = append(texts, text)
}
}
result := strings.Join(texts, "\n")
if strings.TrimSpace(result) == "" {
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
return result
}
// 返回原始 JSON
return string(content)
}
// buildGenerationConfig 构建 generationConfig
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
config := &GeminiGenerationConfig{
MaxOutputTokens: 64000, // 默认最大输出
StopSequences: DefaultStopSequences,
}
// Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true,
}
if req.Thinking.BudgetTokens > 0 {
budget := req.Thinking.BudgetTokens
// gemini-2.5-flash 上限 24576
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > 24576 {
budget = 24576
}
config.ThinkingConfig.ThinkingBudget = budget
}
}
// 其他参数
if req.Temperature != nil {
config.Temperature = req.Temperature
}
if req.TopP != nil {
config.TopP = req.TopP
}
if req.TopK != nil {
config.TopK = req.TopK
}
return config
}
// buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 {
return nil
}
// 检查是否有 web_search 工具
hasWebSearch := false
for _, tool := range tools {
if tool.Name == "web_search" {
hasWebSearch = true
break
}
}
if hasWebSearch {
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
// 清理 JSON Schema
params := cleanJSONSchema(tool.InputSchema)
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: tool.Description,
Parameters: params,
})
}
if len(funcDecls) == 0 {
return nil
}
return []GeminiToolDeclaration{{
FunctionDeclarations: funcDecls,
}}
}
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT)
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
// excludedSchemaKeys 不支持的 schema 字段
var excludedSchemaKeys = map[string]bool{
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for _, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null,返回 STRING
return "STRING"
default:
return value
}
}
package antigravity
import (
"encoding/json"
"fmt"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
// 解包 v1internal 响应
var v1Resp V1InternalResponse
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
// 尝试直接解析为 GeminiResponse
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
processor := NewNonStreamingProcessor()
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
// 序列化
respBytes, err := json.Marshal(claudeResp)
if err != nil {
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
}
return respBytes, &claudeResp.Usage, nil
}
// NonStreamingProcessor 非流式响应处理器
type NonStreamingProcessor struct {
contentBlocks []ClaudeContentItem
textBuilder string
thinkingBuilder string
thinkingSignature string
trailingSignature string
hasToolCall bool
}
// NewNonStreamingProcessor 创建非流式响应处理器
func NewNonStreamingProcessor() *NonStreamingProcessor {
return &NonStreamingProcessor{
contentBlocks: make([]ClaudeContentItem, 0),
}
}
// Process 处理 Gemini 响应
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
// 获取 parts
var parts []GeminiPart
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
parts = geminiResp.Candidates[0].Content.Parts
}
// 处理所有 parts
for _, part := range parts {
p.processPart(&part)
}
// 刷新剩余内容
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
}
// 构建响应
return p.buildResponse(geminiResp, responseID, originalModel)
}
// processPart 处理单个 part
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
signature := part.ThoughtSignature
// 1. FunctionCall 处理
if part.FunctionCall != nil {
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.hasToolCall = true
// 生成 tool_use id
toolID := part.FunctionCall.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
}
item := ClaudeContentItem{
Type: "tool_use",
ID: toolID,
Name: part.FunctionCall.Name,
Input: part.FunctionCall.Args,
}
if signature != "" {
item.Signature = signature
}
p.contentBlocks = append(p.contentBlocks, item)
return
}
// 2. Text 处理
if part.Text != "" || part.Thought {
if part.Thought {
// Thinking part
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.flushThinking()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.thinkingBuilder += part.Text
if signature != "" {
p.thinkingSignature = signature
}
} else {
// 普通 Text
if part.Text == "" {
// 空 text 带签名 - 暂存
if signature != "" {
p.trailingSignature = signature
}
return
}
p.flushThinking()
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
}
}
}
// 3. InlineData (Image) 处理
if part.InlineData != nil && part.InlineData.Data != "" {
p.flushThinking()
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data)
p.textBuilder += markdownImg
p.flushText()
}
}
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: p.textBuilder,
})
p.textBuilder = ""
}
// flushThinking 刷新 thinking builder
func (p *NonStreamingProcessor) flushThinking() {
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: p.thinkingBuilder,
Signature: p.thinkingSignature,
})
p.thinkingBuilder = ""
p.thinkingSignature = ""
}
// buildResponse 构建最终响应
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
}
stopReason := "end_turn"
if p.hasToolCall {
stopReason = "tool_use"
} else if finishReason == "MAX_TOKENS" {
stopReason = "max_tokens"
}
usage := ClaudeUsage{}
if geminiResp.UsageMetadata != nil {
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
}
// 生成响应 ID
respID := responseID
if respID == "" {
respID = geminiResp.ResponseID
}
if respID == "" {
respID = "msg_" + generateRandomID()
}
return &ClaudeResponse{
ID: respID,
Type: "message",
Role: "assistant",
Model: originalModel,
Content: p.contentBlocks,
StopReason: stopReason,
Usage: usage,
}
}
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12)
for i := range result {
result[i] = chars[i%len(chars)]
}
return string(result)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment