Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
......@@ -21,11 +21,18 @@ var (
// - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户)
// - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户)
endpointPrefix = getEnv("ENDPOINT_PREFIX", "")
claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3"
geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f"
testInterval = 1 * time.Second // 测试间隔,防止限流
)
const (
// 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。
// 例如:
// export CLAUDE_API_KEY="sk-..."
// export GEMINI_API_KEY="sk-..."
claudeAPIKeyEnv = "CLAUDE_API_KEY"
geminiAPIKeyEnv = "GEMINI_API_KEY"
)
func getEnv(key, defaultVal string) string {
if v := os.Getenv(key); v != "" {
return v
......@@ -65,16 +72,45 @@ func TestMain(m *testing.M) {
if endpointPrefix != "" {
mode = "Antigravity 模式"
}
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode)
claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != ""
geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != ""
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n",
baseURL,
endpointPrefix,
mode,
claudeAPIKeyEnv,
claudeKeySet,
geminiAPIKeyEnv,
geminiKeySet,
)
os.Exit(m.Run())
}
func requireClaudeAPIKey(t *testing.T) string {
t.Helper()
key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
if key == "" {
t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv)
}
return key
}
func requireGeminiAPIKey(t *testing.T) string {
t.Helper()
key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
if key == "" {
t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv)
}
return key
}
// TestClaudeModelsList 测试 GET /v1/models
func TestClaudeModelsList(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
url := baseURL + endpointPrefix + "/v1/models"
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("Authorization", "Bearer "+claudeKey)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
......@@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) {
// TestGeminiModelsList 测试 GET /v1beta/models
func TestGeminiModelsList(t *testing.T) {
geminiKey := requireGeminiAPIKey(t)
url := baseURL + endpointPrefix + "/v1beta/models"
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
req.Header.Set("Authorization", "Bearer "+geminiKey)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
......@@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) {
// TestClaudeMessages 测试 Claude /v1/messages 接口
func TestClaudeMessages(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
for i, model := range claudeModels {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_非流式", func(t *testing.T) {
testClaudeMessage(t, model, false)
testClaudeMessage(t, claudeKey, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
testClaudeMessage(t, claudeKey, model, true)
})
}
}
func testClaudeMessage(t *testing.T, model string, stream bool) {
func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) {
url := baseURL + endpointPrefix + "/v1/messages"
payload := map[string]any{
......@@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
......@@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
// TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口
func TestGeminiGenerateContent(t *testing.T) {
geminiKey := requireGeminiAPIKey(t)
for i, model := range geminiModels {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_非流式", func(t *testing.T) {
testGeminiGenerate(t, model, false)
testGeminiGenerate(t, geminiKey, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
testGeminiGenerate(t, geminiKey, model, true)
})
}
}
func testGeminiGenerate(t *testing.T, model string, stream bool) {
func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) {
action := "generateContent"
if stream {
action = "streamGenerateContent"
......@@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
req.Header.Set("Authorization", "Bearer "+geminiKey)
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
......@@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
// TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求
// 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段
func TestClaudeMessagesWithComplexTools(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
// 测试模型列表(只测试几个代表性模型)
models := []string{
"claude-opus-4-5-20251101", // Claude 模型
......@@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) {
time.Sleep(testInterval)
}
t.Run(model+"_复杂工具", func(t *testing.T) {
testClaudeMessageWithTools(t, model)
testClaudeMessageWithTools(t, claudeKey, model)
})
}
}
func testClaudeMessageWithTools(t *testing.T, model string) {
func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具)
......@@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
......@@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
// 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时,
// 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误
func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash
}
......@@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
time.Sleep(testInterval)
}
t.Run(model+"_thinking模式工具调用", func(t *testing.T) {
testClaudeThinkingWithToolHistory(t, model)
testClaudeThinkingWithToolHistory(t, claudeKey, model)
})
}
}
func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话
......@@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
......@@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
claudeKey := requireClaudeAPIKey(t)
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
......@@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
testClaudeMessage(t, claudeKey, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
testClaudeMessage(t, claudeKey, model, true)
})
}
}
......@@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature
}
......@@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) {
time.Sleep(testInterval)
}
t.Run(model+"_无signature", func(t *testing.T) {
testClaudeWithNoSignature(t, model)
testClaudeWithNoSignature(t, claudeKey, model)
})
}
}
func testClaudeWithNoSignature(t *testing.T, model string) {
func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话包含 thinking block 但没有 signature
......@@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second}
......@@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
geminiKey := requireGeminiAPIKey(t)
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
......@@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
testGeminiGenerate(t, geminiKey, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
testGeminiGenerate(t, geminiKey, model, true)
})
}
}
//go:build e2e
package integration
import (
"os"
"strings"
"testing"
)
// =============================================================================
// E2E Mock 模式支持
// =============================================================================
// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。
// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。
// isMockMode 检查是否启用 Mock 模式
func isMockMode() bool {
return strings.EqualFold(os.Getenv("E2E_MOCK"), "true")
}
// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试
func skipIfNoRealAPI(t *testing.T) {
t.Helper()
if isMockMode() {
return // Mock 模式下不跳过
}
claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
if claudeKey == "" && geminiKey == "" {
t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试")
}
}
// =============================================================================
// API Key 脱敏(Task 6.10)
// =============================================================================
// safeLogKey 安全地记录 API Key(仅显示前 8 位)
func safeLogKey(t *testing.T, prefix string, key string) {
t.Helper()
key = strings.TrimSpace(key)
if len(key) <= 8 {
t.Logf("%s: ***(长度: %d)", prefix, len(key))
return
}
t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key))
}
//go:build e2e
package integration
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
)
// E2E 用户流程测试
// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量
var (
testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local"
testUserPassword = "E2eTest@12345"
testUserName = "e2e-test-user"
)
// TestUserRegistrationAndLogin 测试用户注册和登录流程
func TestUserRegistrationAndLogin(t *testing.T) {
// 步骤 1: 注册新用户
t.Run("注册新用户", func(t *testing.T) {
payload := map[string]string{
"email": testUserEmail,
"password": testUserPassword,
"username": testUserName,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
if err != nil {
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭)
switch resp.StatusCode {
case 200:
t.Logf("✅ 用户注册成功: %s", testUserEmail)
case 400:
t.Logf("⚠️ 用户可能已存在: %s", string(respBody))
case 403:
t.Skipf("注册功能已关闭: %s", string(respBody))
default:
t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody))
}
})
// 步骤 2: 登录获取 JWT
var accessToken string
t.Run("用户登录获取JWT", func(t *testing.T) {
payload := map[string]string{
"email": testUserEmail,
"password": testUserPassword,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
if err != nil {
t.Fatalf("登录请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody))
return
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析登录响应失败: %v", err)
}
// 尝试从标准响应格式获取 token
if token, ok := result["access_token"].(string); ok && token != "" {
accessToken = token
} else if data, ok := result["data"].(map[string]any); ok {
if token, ok := data["access_token"].(string); ok {
accessToken = token
}
}
if accessToken == "" {
t.Skipf("未获取到 access_token,响应: %s", string(respBody))
return
}
// 验证 token 不为空且格式基本正确
if len(accessToken) < 10 {
t.Fatalf("access_token 格式异常: %s", accessToken)
}
t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken))
})
if accessToken == "" {
t.Skip("未获取到 JWT,跳过后续测试")
return
}
// 步骤 3: 使用 JWT 获取当前用户信息
t.Run("获取当前用户信息", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
}
t.Logf("✅ 成功获取用户信息")
})
}
// TestAPIKeyLifecycle 测试 API Key 的创建和使用
func TestAPIKeyLifecycle(t *testing.T) {
// 先登录获取 JWT
accessToken := loginTestUser(t)
if accessToken == "" {
t.Skip("无法登录,跳过 API Key 生命周期测试")
return
}
var apiKey string
// 步骤 1: 创建 API Key
t.Run("创建API_Key", func(t *testing.T) {
payload := map[string]string{
"name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()),
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
if err != nil {
t.Fatalf("创建 API Key 请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody))
return
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
// 从响应中提取 key
if key, ok := result["key"].(string); ok {
apiKey = key
} else if data, ok := result["data"].(map[string]any); ok {
if key, ok := data["key"].(string); ok {
apiKey = key
}
}
if apiKey == "" {
t.Skipf("未获取到 API Key,响应: %s", string(respBody))
return
}
// 验证 API Key 脱敏日志(只显示前 8 位)
masked := apiKey
if len(masked) > 8 {
masked = masked[:8] + "..."
}
t.Logf("✅ API Key 创建成功: %s", masked)
})
if apiKey == "" {
t.Skip("未创建 API Key,跳过后续测试")
return
}
// 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用)
t.Run("使用API_Key调用网关", func(t *testing.T) {
// 尝试调用 models 列表(最轻量的 API 调用)
resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey)
if err != nil {
t.Fatalf("网关请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 可能返回 200(成功)或 402(余额不足)或 403(无可用账户)
switch {
case resp.StatusCode == 200:
t.Logf("✅ API Key 网关调用成功")
case resp.StatusCode == 402:
t.Logf("⚠️ 余额不足,但 API Key 认证通过")
case resp.StatusCode == 403:
t.Logf("⚠️ 无可用账户,但 API Key 认证通过")
default:
t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody))
}
})
// 步骤 3: 查询用量记录
t.Run("查询用量记录", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
if err != nil {
t.Fatalf("用量查询请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body))
return
}
t.Logf("✅ 用量查询成功")
})
}
// =============================================================================
// 辅助函数
// =============================================================================
func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) {
t.Helper()
url := baseURL + path
var bodyReader io.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequest(method, url, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
client := &http.Client{Timeout: 30 * time.Second}
return client.Do(req)
}
func loginTestUser(t *testing.T) string {
t.Helper()
// 先尝试用管理员账户登录
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local")
adminPassword := getEnv("ADMIN_PASSWORD", "")
if adminPassword == "" {
// 尝试用测试用户
adminEmail = testUserEmail
adminPassword = testUserPassword
}
payload := map[string]string{
"email": adminEmail,
"password": adminPassword,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return ""
}
respBody, _ := io.ReadAll(resp.Body)
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if token, ok := result["access_token"].(string); ok {
return token
}
if data, ok := result["data"].(map[string]any); ok {
if token, ok := data["access_token"].(string); ok {
return token
}
}
return ""
}
// redactAPIKey API Key 脱敏,只显示前 8 位
func redactAPIKey(key string) string {
key = strings.TrimSpace(key)
if len(key) <= 8 {
return "***"
}
return key[:8] + "..."
}
......@@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) {
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}
func TestRateLimiterDifferentIPsIndependent(t *testing.T) {
gin.SetMode(gin.TestMode)
callCounts := make(map[string]int64)
originalRun := rateLimitRun
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
callCounts[key]++
return callCounts[key], false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun
})
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
router := gin.New()
router.Use(limiter.Limit("api", 1, time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
// 第一个 IP 的请求应通过
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "10.0.0.1:1234"
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过")
// 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响)
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "10.0.0.2:5678"
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过")
// 第一个 IP 的第二次请求应被限流
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "10.0.0.1:1234"
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流")
}
func TestRateLimiterSuccessAndLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
......
......@@ -151,6 +151,9 @@ var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
......@@ -161,6 +164,10 @@ var geminiModels = []modelDef{
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
......
package antigravity
import "testing"
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byID := make(map[string]ClaudeModel, len(models))
for _, m := range models {
byID[m.ID] = m
}
requiredIDs := []string{
"claude-opus-4-6-thinking",
"gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview",
"gemini-3-pro-image", // legacy compatibility
}
for _, id := range requiredIDs {
if _, ok := byID[id]; !ok {
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
}
}
}
......@@ -14,6 +14,9 @@ import (
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
......@@ -33,7 +36,7 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("User-Agent", GetUserAgent())
return req, nil
}
......@@ -149,22 +152,26 @@ type Client struct {
httpClient *http.Client
}
func NewClient(proxyURL string) *Client {
func NewClient(proxyURL string) (*Client, error) {
client := &http.Client{
Timeout: 30 * time.Second,
}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
_, parsed, err := proxyurl.Parse(proxyURL)
if err != nil {
return nil, err
}
if parsed != nil {
transport := &http.Transport{}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, fmt.Errorf("configure proxy: %w", err)
}
client.Transport = transport
}
return &Client{
httpClient: client,
}
}, nil
}
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
......@@ -204,9 +211,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("client_secret", clientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
......@@ -243,9 +255,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("client_secret", clientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")
......@@ -333,7 +350,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("User-Agent", GetUserAgent())
resp, err := c.httpClient.Do(req)
if err != nil {
......@@ -412,7 +429,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("User-Agent", GetUserAgent())
resp, err := c.httpClient.Do(req)
if err != nil {
......@@ -532,7 +549,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("User-Agent", GetUserAgent())
resp, err := c.httpClient.Do(req)
if err != nil {
......
//go:build unit
package antigravity
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// NewAPIRequestWithURL
// ---------------------------------------------------------------------------
func TestNewAPIRequestWithURL_普通请求(t *testing.T) {
ctx := context.Background()
baseURL := "https://example.com"
action := "generateContent"
token := "test-token"
body := []byte(`{"prompt":"hello"}`)
req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
// 验证 URL 不含 ?alt=sse
expectedURL := "https://example.com/v1internal:generateContent"
if req.URL.String() != expectedURL {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL)
}
// 验证请求方法
if req.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", req.Method)
}
// 验证 Headers
if ct := req.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ua := req.Header.Get("User-Agent"); ua != GetUserAgent() {
t.Errorf("User-Agent 不匹配: got %s, want %s", ua, GetUserAgent())
}
}
func TestNewAPIRequestWithURL_流式请求(t *testing.T) {
ctx := context.Background()
baseURL := "https://example.com"
action := "streamGenerateContent"
token := "tok"
body := []byte(`{}`)
req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse"
if req.URL.String() != expectedURL {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL)
}
}
func TestNewAPIRequestWithURL_空Body(t *testing.T) {
ctx := context.Background()
req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
if req.Body == nil {
t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)")
}
}
// ---------------------------------------------------------------------------
// NewAPIRequest
// ---------------------------------------------------------------------------
func TestNewAPIRequest_使用默认URL(t *testing.T) {
ctx := context.Background()
req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
expected := BaseURL + "/v1internal:generateContent"
if req.URL.String() != expected {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected)
}
}
// ---------------------------------------------------------------------------
// TierInfo.UnmarshalJSON
// ---------------------------------------------------------------------------
func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) {
data := []byte(`"free-tier"`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if tier.ID != "free-tier" {
t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID)
}
if tier.Name != "" {
t.Errorf("Name 应为空: got %s", tier.Name)
}
}
func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) {
data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if tier.ID != "g1-pro-tier" {
t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID)
}
if tier.Name != "Pro" {
t.Errorf("Name 不匹配: got %s, want Pro", tier.Name)
}
if tier.Description != "Pro plan" {
t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description)
}
}
func TestTierInfo_UnmarshalJSON_null(t *testing.T) {
data := []byte(`null`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化 null 失败: %v", err)
}
if tier.ID != "" {
t.Errorf("null 场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) {
data := []byte(``)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化空数据失败: %v", err)
}
if tier.ID != "" {
t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) {
data := []byte(` null `)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化空格 null 失败: %v", err)
}
if tier.ID != "" {
t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
// 模拟 LoadCodeAssistResponse 中的嵌套反序列化
jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}`
var resp LoadCodeAssistResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化嵌套结构失败: %v", err)
}
if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" {
t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier)
}
if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" {
t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier)
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssistResponse.GetTier
// ---------------------------------------------------------------------------
func TestGetTier_PaidTier优先(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: "g1-pro-tier"},
}
if got := resp.GetTier(); got != "g1-pro-tier" {
t.Errorf("应返回 paidTier: got %s", got)
}
}
func TestGetTier_回退到CurrentTier(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
}
if got := resp.GetTier(); got != "free-tier" {
t.Errorf("应返回 currentTier: got %s", got)
}
}
func TestGetTier_PaidTier为空ID(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: ""},
}
// paidTier.ID 为空时应回退到 currentTier
if got := resp.GetTier(); got != "free-tier" {
t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got)
}
}
func TestGetTier_两者都为nil(t *testing.T) {
resp := &LoadCodeAssistResponse{}
if got := resp.GetTier(); got != "" {
t.Errorf("两者都为 nil 时应返回空字符串: got %s", got)
}
}
// ---------------------------------------------------------------------------
// NewClient
// ---------------------------------------------------------------------------
func mustNewClient(t *testing.T, proxyURL string) *Client {
t.Helper()
client, err := NewClient(proxyURL)
if err != nil {
t.Fatalf("NewClient(%q) failed: %v", proxyURL, err)
}
return client
}
func TestNewClient_无代理(t *testing.T) {
client, err := NewClient("")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
if client.httpClient == nil {
t.Fatal("httpClient 为 nil")
}
if client.httpClient.Timeout != 30*time.Second {
t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout)
}
// 无代理时 Transport 应为 nil(使用默认)
if client.httpClient.Transport != nil {
t.Error("无代理时 Transport 应为 nil")
}
}
func TestNewClient_有代理(t *testing.T) {
client, err := NewClient("http://proxy.example.com:8080")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
if client.httpClient.Transport == nil {
t.Fatal("有代理时 Transport 不应为 nil")
}
}
func TestNewClient_空格代理(t *testing.T) {
client, err := NewClient(" ")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
// 空格代理应等同于无代理
if client.httpClient.Transport != nil {
t.Error("空格代理 Transport 应为 nil")
}
}
func TestNewClient_无效代理URL(t *testing.T) {
// 无效 URL 应返回 error
_, err := NewClient("://invalid")
if err == nil {
t.Fatal("无效代理 URL 应返回错误")
}
if !strings.Contains(err.Error(), "invalid proxy URL") {
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
}
}
// ---------------------------------------------------------------------------
// isConnectionError
// ---------------------------------------------------------------------------
func TestIsConnectionError_nil(t *testing.T) {
if isConnectionError(nil) {
t.Error("nil 错误不应判定为连接错误")
}
}
func TestIsConnectionError_超时错误(t *testing.T) {
// 使用 net.OpError 包装超时
err := &net.OpError{
Op: "dial",
Net: "tcp",
Err: &timeoutError{},
}
if !isConnectionError(err) {
t.Error("超时错误应判定为连接错误")
}
}
// timeoutError 实现 net.Error 接口用于测试
type timeoutError struct{}
func (e *timeoutError) Error() string { return "timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestIsConnectionError_netOpError(t *testing.T) {
err := &net.OpError{
Op: "dial",
Net: "tcp",
Err: fmt.Errorf("connection refused"),
}
if !isConnectionError(err) {
t.Error("net.OpError 应判定为连接错误")
}
}
func TestIsConnectionError_urlError(t *testing.T) {
err := &url.Error{
Op: "Get",
URL: "https://example.com",
Err: fmt.Errorf("some error"),
}
if !isConnectionError(err) {
t.Error("url.Error 应判定为连接错误")
}
}
func TestIsConnectionError_普通错误(t *testing.T) {
err := fmt.Errorf("some random error")
if isConnectionError(err) {
t.Error("普通错误不应判定为连接错误")
}
}
func TestIsConnectionError_包装的netOpError(t *testing.T) {
inner := &net.OpError{
Op: "dial",
Net: "tcp",
Err: fmt.Errorf("connection refused"),
}
err := fmt.Errorf("wrapping: %w", inner)
if !isConnectionError(err) {
t.Error("被包装的 net.OpError 应判定为连接错误")
}
}
// ---------------------------------------------------------------------------
// shouldFallbackToNextURL
// ---------------------------------------------------------------------------
func TestShouldFallbackToNextURL_连接错误(t *testing.T) {
err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")}
if !shouldFallbackToNextURL(err, 0) {
t.Error("连接错误应触发 URL 降级")
}
}
func TestShouldFallbackToNextURL_状态码(t *testing.T) {
tests := []struct {
name string
statusCode int
want bool
}{
{"429 Too Many Requests", http.StatusTooManyRequests, true},
{"408 Request Timeout", http.StatusRequestTimeout, true},
{"404 Not Found", http.StatusNotFound, true},
{"500 Internal Server Error", http.StatusInternalServerError, true},
{"502 Bad Gateway", http.StatusBadGateway, true},
{"503 Service Unavailable", http.StatusServiceUnavailable, true},
{"200 OK", http.StatusOK, false},
{"201 Created", http.StatusCreated, false},
{"400 Bad Request", http.StatusBadRequest, false},
{"401 Unauthorized", http.StatusUnauthorized, false},
{"403 Forbidden", http.StatusForbidden, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldFallbackToNextURL(nil, tt.statusCode)
if got != tt.want {
t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want)
}
})
}
}
func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
if shouldFallbackToNextURL(nil, http.StatusOK) {
t.Error("无错误且 200 不应触发 URL 降级")
}
}
// ---------------------------------------------------------------------------
// Client.ExchangeCode (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_成功(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求方法
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
// 验证 Content-Type
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
// 验证请求体参数
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
if r.FormValue("code") != "auth-code" {
t.Errorf("code 不匹配: got %s", r.FormValue("code"))
}
if r.FormValue("code_verifier") != "verifier123" {
t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier"))
}
if r.FormValue("grant_type") != "authorization_code" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "access-tok",
ExpiresIn: 3600,
TokenType: "Bearer",
RefreshToken: "refresh-tok",
})
}))
defer server.Close()
// 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过)
// 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为
// 这里通过构造一个直接调用 mock server 的测试
client := &Client{httpClient: server.Client()}
// 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL
// 需要使用 httptest 的 Transport 重定向
originalTokenURL := TokenURL
// 我们改为直接构造请求来测试逻辑
_ = originalTokenURL
_ = client
// 改用直接构造请求测试 mock server 响应
ctx := context.Background()
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", "test-secret")
params.Set("code", "auth-code")
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
params.Set("code_verifier", "verifier123")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode()))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
t.Fatalf("解码失败: %v", err)
}
if tokenResp.AccessToken != "access-tok" {
t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken)
}
if tokenResp.RefreshToken != "refresh-tok" {
t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken)
}
}
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := mustNewClient(t, "")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
}
}
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
}))
defer server.Close()
// 直接测试 mock server 的错误响应
resp, err := server.Client().Get(server.URL)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode)
}
}
// ---------------------------------------------------------------------------
// Client.RefreshToken (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_MockServer(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("grant_type") != "refresh_token" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("refresh_token") != "old-refresh-tok" {
t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-tok",
ExpiresIn: 3600,
TokenType: "Bearer",
})
}))
defer server.Close()
ctx := context.Background()
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", "test-secret")
params.Set("refresh_token", "old-refresh-tok")
params.Set("grant_type", "refresh_token")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode()))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
t.Fatalf("解码失败: %v", err)
}
if tokenResp.AccessToken != "new-access-tok" {
t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken)
}
}
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := mustNewClient(t, "")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.GetUserInfo (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_GetUserInfo_成功(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-access-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(UserInfo{
Email: "user@example.com",
Name: "Test User",
GivenName: "Test",
FamilyName: "User",
Picture: "https://example.com/photo.jpg",
})
}))
defer server.Close()
// 直接通过 mock server 测试 GetUserInfo 的行为逻辑
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Authorization", "Bearer test-access-token")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
t.Fatalf("解码失败: %v", err)
}
if userInfo.Email != "user@example.com" {
t.Errorf("Email 不匹配: got %s", userInfo.Email)
}
if userInfo.Name != "Test User" {
t.Errorf("Name 不匹配: got %s", userInfo.Name)
}
}
func TestClient_GetUserInfo_服务器返回错误(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_token"}`))
}))
defer server.Close()
resp, err := server.Client().Get(server.URL)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode)
}
}
// ---------------------------------------------------------------------------
// TokenResponse / UserInfo JSON 序列化
// ---------------------------------------------------------------------------
func TestTokenResponse_JSON序列化(t *testing.T) {
jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}`
var resp TokenResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if resp.AccessToken != "at" {
t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken)
}
if resp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn)
}
if resp.RefreshToken != "rt" {
t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken)
}
}
func TestUserInfo_JSON序列化(t *testing.T) {
jsonData := `{"email":"a@b.com","name":"Alice"}`
var info UserInfo
if err := json.Unmarshal([]byte(jsonData), &info); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if info.Email != "a@b.com" {
t.Errorf("Email 不匹配: got %s", info.Email)
}
if info.Name != "Alice" {
t.Errorf("Name 不匹配: got %s", info.Name)
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssistResponse JSON 序列化
// ---------------------------------------------------------------------------
func TestLoadCodeAssistResponse_完整JSON(t *testing.T) {
jsonData := `{
"cloudaicompanionProject": "proj-123",
"currentTier": "free-tier",
"paidTier": {"id": "g1-pro-tier", "name": "Pro"},
"ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}]
}`
var resp LoadCodeAssistResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if resp.CloudAICompanionProject != "proj-123" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if resp.GetTier() != "g1-pro-tier" {
t.Errorf("GetTier 不匹配: got %s", resp.GetTier())
}
if len(resp.IneligibleTiers) != 1 {
t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers))
}
if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" {
t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode)
}
}
// ===========================================================================
// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求
// ===========================================================================
// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server
type redirectRoundTripper struct {
// 原始 URL 前缀 -> 替换目标 URL 的映射
redirects map[string]string
transport http.RoundTripper
}
func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
originalURL := req.URL.String()
for prefix, target := range rt.redirects {
if strings.HasPrefix(originalURL, prefix) {
newURL := target + strings.TrimPrefix(originalURL, prefix)
parsed, err := url.Parse(newURL)
if err != nil {
return nil, err
}
req.URL = parsed
break
}
}
if rt.transport == nil {
return http.DefaultTransport.RoundTrip(req)
}
return rt.transport.RoundTrip(req)
}
// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server
func newTestClientWithRedirect(redirects map[string]string) *Client {
return &Client{
httpClient: &http.Client{
Timeout: 10 * time.Second,
Transport: &redirectRoundTripper{
redirects: redirects,
},
},
}
}
// ---------------------------------------------------------------------------
// Client.ExchangeCode - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
if r.FormValue("code") != "test-auth-code" {
t.Errorf("code 不匹配: got %s", r.FormValue("code"))
}
if r.FormValue("code_verifier") != "test-verifier" {
t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier"))
}
if r.FormValue("grant_type") != "authorization_code" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("redirect_uri") != RedirectURI {
t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
TokenType: "Bearer",
Scope: "openid email",
RefreshToken: "new-refresh-token",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier")
if err != nil {
t.Fatalf("ExchangeCode 失败: %v", err)
}
if tokenResp.AccessToken != "new-access-token" {
t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken)
}
if tokenResp.RefreshToken != "new-refresh-token" {
t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken)
}
if tokenResp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn)
}
if tokenResp.TokenType != "Bearer" {
t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType)
}
if tokenResp.Scope != "openid email" {
t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope)
}
}
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.ExchangeCode(context.Background(), "expired-code", "verifier")
if err == nil {
t.Fatal("服务器返回 400 时应返回错误")
}
if !strings.Contains(err.Error(), "token 交换失败") {
t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("错误信息应包含状态码 400: got %s", err.Error())
}
}
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{invalid json`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "token 解析失败") {
t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error())
}
}
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) // 模拟慢响应
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
_, err := client.ExchangeCode(ctx, "code", "verifier")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.RefreshToken - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("grant_type") != "refresh_token" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("refresh_token") != "my-refresh-token" {
t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token"))
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "refreshed-access-token",
ExpiresIn: 3600,
TokenType: "Bearer",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token")
if err != nil {
t.Fatalf("RefreshToken 失败: %v", err)
}
if tokenResp.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken)
}
if tokenResp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn)
}
}
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.RefreshToken(context.Background(), "revoked-token")
if err == nil {
t.Fatal("服务器返回 401 时应返回错误")
}
if !strings.Contains(err.Error(), "token 刷新失败") {
t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error())
}
}
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`not-json`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "token 解析失败") {
t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error())
}
}
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := client.RefreshToken(ctx, "refresh-tok")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.GetUserInfo - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_GetUserInfo_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("请求方法不匹配: got %s, want GET", r.Method)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer user-access-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(UserInfo{
Email: "test@example.com",
Name: "Test User",
GivenName: "Test",
FamilyName: "User",
Picture: "https://example.com/avatar.jpg",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
userInfo, err := client.GetUserInfo(context.Background(), "user-access-token")
if err != nil {
t.Fatalf("GetUserInfo 失败: %v", err)
}
if userInfo.Email != "test@example.com" {
t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email)
}
if userInfo.Name != "Test User" {
t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name)
}
if userInfo.GivenName != "Test" {
t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName)
}
if userInfo.FamilyName != "User" {
t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName)
}
if userInfo.Picture != "https://example.com/avatar.jpg" {
t.Errorf("Picture 不匹配: got %s", userInfo.Picture)
}
}
func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_token"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
_, err := client.GetUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 401 时应返回错误")
}
if !strings.Contains(err.Error(), "获取用户信息失败") {
t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "401") {
t.Errorf("错误信息应包含状态码 401: got %s", err.Error())
}
}
func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{broken`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
_, err := client.GetUserInfo(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "用户信息解析失败") {
t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error())
}
}
func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := client.GetUserInfo(ctx, "token")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.LoadCodeAssist - 真正调用方法的测试
// ---------------------------------------------------------------------------
// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复
func withMockBaseURLs(t *testing.T, urls []string) {
t.Helper()
origBaseURLs := BaseURLs
origBaseURL := BaseURL
BaseURLs = urls
if len(urls) > 0 {
BaseURL = urls[0]
}
t.Cleanup(func() {
BaseURLs = origBaseURLs
BaseURL = origBaseURL
})
}
func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") {
t.Errorf("URL 路径不匹配: got %s", r.URL.Path)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() {
t.Errorf("User-Agent 不匹配: got %s", ua)
}
// 验证请求体
var reqBody LoadCodeAssistRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
t.Fatalf("解析请求体失败: %v", err)
}
if reqBody.Metadata.IDEType != "ANTIGRAVITY" {
t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"cloudaicompanionProject": "test-project-123",
"currentTier": {"id": "free-tier", "name": "Free"},
"paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"}
}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token")
if err != nil {
t.Fatalf("LoadCodeAssist 失败: %v", err)
}
if resp.CloudAICompanionProject != "test-project-123" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if resp.GetTier() != "g1-pro-tier" {
t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier())
}
if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" {
t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier)
}
if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" {
t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier)
}
// 验证原始 JSON map
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
if rawResp["cloudaicompanionProject"] != "test-project-123" {
t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"])
}
}
func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
}
if !strings.Contains(err.Error(), "loadCodeAssist 失败") {
t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "403") {
t.Errorf("错误信息应包含状态码 403: got %s", err.Error())
}
}
func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{not valid json!!!`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "响应解析失败") {
t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error())
}
}
func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) {
// 第一个 server 返回 500,第二个 server 返回成功
callCount := 0
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"internal"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"cloudaicompanionProject": "fallback-project",
"currentTier": {"id": "free-tier", "name": "Free"}
}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err)
}
if resp.CloudAICompanionProject != "fallback-project" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if callCount != 2 {
t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount)
}
}
func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`{"error":"unavailable"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte(`{"error":"bad_gateway"}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
}
}
func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := client.LoadCodeAssist(ctx, "token")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.FetchAvailableModels - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") {
t.Errorf("URL 路径不匹配: got %s", r.URL.Path)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() {
t.Errorf("User-Agent 不匹配: got %s", ua)
}
// 验证请求体
var reqBody FetchAvailableModelsRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
t.Fatalf("解析请求体失败: %v", err)
}
if reqBody.Project != "project-abc" {
t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"models": {
"gemini-2.0-flash": {
"quotaInfo": {
"remainingFraction": 0.85,
"resetTime": "2025-01-01T00:00:00Z"
}
},
"gemini-2.5-pro": {
"quotaInfo": {
"remainingFraction": 0.5
}
}
}
}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
}
if resp.Models == nil {
t.Fatal("Models 不应为 nil")
}
if len(resp.Models) != 2 {
t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models))
}
flashModel, ok := resp.Models["gemini-2.0-flash"]
if !ok {
t.Fatal("缺少 gemini-2.0-flash 模型")
}
if flashModel.QuotaInfo == nil {
t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil")
}
if flashModel.QuotaInfo.RemainingFraction != 0.85 {
t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction)
}
if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" {
t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime)
}
proModel, ok := resp.Models["gemini-2.5-pro"]
if !ok {
t.Fatal("缺少 gemini-2.5-pro 模型")
}
if proModel.QuotaInfo == nil {
t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil")
}
if proModel.QuotaInfo.RemainingFraction != 0.5 {
t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction)
}
// 验证原始 JSON map
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
if rawResp["models"] == nil {
t.Error("rawResp models 不应为 nil")
}
}
func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
}
if !strings.Contains(err.Error(), "fetchAvailableModels 失败") {
t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error())
}
}
func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`<<<not json>>>`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "响应解析失败") {
t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error())
}
}
func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) {
callCount := 0
// 第一个 server 返回 429,第二个 server 返回成功
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":"rate_limited"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models": {"model-a": {}}}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err)
}
if _, ok := resp.Models["model-a"]; !ok {
t.Error("应返回 fallback server 的模型")
}
if callCount != 2 {
t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount)
}
}
func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`not found`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`internal error`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
}
}
func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := client.FetchAvailableModels(ctx, "token", "proj")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models": {}}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := mustNewClient(t, "")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
}
if resp.Models == nil {
t.Fatal("Models 不应为 nil")
}
if len(resp.Models) != 0 {
t.Errorf("Models 应为空: got %d", len(resp.Models))
}
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试
// ---------------------------------------------------------------------------
func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusRequestTimeout)
_, _ = w.Write([]byte(`timeout`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err)
}
if resp.CloudAICompanionProject != "p2" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
}
func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`not found`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := mustNewClient(t, "")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err)
}
if _, ok := resp.Models["m1"]; !ok {
t.Error("应返回 fallback server 的模型 m1")
}
}
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
t.Parallel()
......
......@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
}
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image 支持)
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
type GeminiImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
......
......@@ -6,10 +6,14 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
......@@ -20,7 +24,9 @@ const (
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri(用户需手动复制 code)
RedirectURI = "http://localhost:8085/callback"
......@@ -32,9 +38,6 @@ const (
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent(与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
......@@ -46,6 +49,35 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
var defaultUserAgentVersion = "1.19.6"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
func init() {
// 从环境变量读取版本号,未设置则使用默认值
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
defaultUserAgentVersion = version
}
// 从环境变量读取 client_secret,未设置则使用默认值
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
defaultClientSecret = secret
}
}
// GetUserAgent 返回当前配置的 User-Agent
func GetUserAgent() string {
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
}
func getClientSecret() (string, error) {
if v := strings.TrimSpace(defaultClientSecret); v != "" {
return v, nil
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
}
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{
antigravityProdBaseURL, // prod (优先)
......
//go:build unit
package antigravity
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/url"
"os"
"strings"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// getClientSecret
// ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
// 需要重新触发 init 逻辑:手动从环境变量读取
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "my-secret-value" {
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
}
}
func TestGetClientSecret_环境变量为空(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("defaultClientSecret 为空时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
}
}
func TestGetClientSecret_环境变量未设置(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("defaultClientSecret 为空时应返回错误")
}
}
func TestGetClientSecret_环境变量含空格(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = " "
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
}
}
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = " valid-secret "
t.Cleanup(func() { defaultClientSecret = old })
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "valid-secret" {
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
}
}
// ---------------------------------------------------------------------------
// ForwardBaseURLs
// ---------------------------------------------------------------------------
func TestForwardBaseURLs_Daily优先(t *testing.T) {
urls := ForwardBaseURLs()
if len(urls) == 0 {
t.Fatal("ForwardBaseURLs 返回空列表")
}
// daily URL 应排在第一位
if urls[0] != antigravityDailyBaseURL {
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
}
// 应包含所有 URL
if len(urls) != len(BaseURLs) {
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
// 验证 prod URL 也在列表中
found := false
for _, u := range urls {
if u == antigravityProdBaseURL {
found = true
break
}
}
if !found {
t.Error("ForwardBaseURLs 中缺少 prod URL")
}
}
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
originalFirst := BaseURLs[0]
_ = ForwardBaseURLs()
// 确保原始 BaseURLs 未被修改
if BaseURLs[0] != originalFirst {
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
}
}
// ---------------------------------------------------------------------------
// URLAvailability
// ---------------------------------------------------------------------------
func TestNewURLAvailability(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if ua == nil {
t.Fatal("NewURLAvailability 返回 nil")
}
if ua.ttl != 5*time.Minute {
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
}
if ua.unavailable == nil {
t.Error("unavailable map 不应为 nil")
}
}
func TestURLAvailability_MarkUnavailable(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后 IsAvailable 应返回 false")
}
}
func TestURLAvailability_MarkSuccess(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
// 先标记为不可用
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后应不可用")
}
// 标记成功后应恢复可用
ua.MarkSuccess(testURL)
if !ua.IsAvailable(testURL) {
t.Error("MarkSuccess 后应恢复可用")
}
// 验证 lastSuccess 被设置
ua.mu.RLock()
if ua.lastSuccess != testURL {
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
}
ua.mu.RUnlock()
}
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
// 使用极短的 TTL
ua := NewURLAvailability(1 * time.Millisecond)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
// 等待 TTL 过期
time.Sleep(5 * time.Millisecond)
if !ua.IsAvailable(testURL) {
t.Error("TTL 过期后 URL 应恢复可用")
}
}
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if !ua.IsAvailable("https://never-marked.com") {
t.Error("未标记的 URL 应默认可用")
}
}
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
// 默认所有 URL 都可用
urls := ua.GetAvailableURLs()
if len(urls) != len(BaseURLs) {
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
}
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
if len(BaseURLs) < 2 {
t.Skip("BaseURLs 少于 2 个,跳过此测试")
}
ua.MarkUnavailable(BaseURLs[0])
urls := ua.GetAvailableURLs()
// 标记的 URL 不应出现在可用列表中
for _, u := range urls {
if u == BaseURLs[0] {
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
}
}
}
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
ua.MarkSuccess("https://c.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
// c.com 应排在第一位
if urls[0] != "https://c.com" {
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
}
// 其余按原始顺序
if urls[1] != "https://a.com" {
t.Errorf("第二个应为 a.com: got %s", urls[1])
}
if urls[2] != "https://b.com" {
t.Errorf("第三个应为 b.com: got %s", urls[2])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://b.com")
ua.MarkUnavailable("https://b.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// b.com 被标记不可用,不应出现
if len(urls) != 1 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
}
if urls[0] != "https://a.com" {
t.Errorf("仅 a.com 应可用: got %s", urls[0])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://not-in-list.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// lastSuccess 不在自定义列表中,不应被添加
if len(urls) != 2 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
}
}
// ---------------------------------------------------------------------------
// SessionStore
// ---------------------------------------------------------------------------
func TestNewSessionStore(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
if store == nil {
t.Fatal("NewSessionStore 返回 nil")
}
if store.sessions == nil {
t.Error("sessions map 不应为 nil")
}
}
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
CodeVerifier: "test-verifier",
ProxyURL: "http://proxy.example.com",
CreatedAt: time.Now(),
}
store.Set("session-1", session)
got, ok := store.Get("session-1")
if !ok {
t.Fatal("Get 应返回 true")
}
if got.State != "test-state" {
t.Errorf("State 不匹配: got %s", got.State)
}
if got.CodeVerifier != "test-verifier" {
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
}
if got.ProxyURL != "http://proxy.example.com" {
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
}
}
func TestSessionStore_Get_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("nonexistent")
if ok {
t.Error("不存在的 session 应返回 false")
}
}
func TestSessionStore_Get_过期(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "expired-state",
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
}
store.Set("expired-session", session)
_, ok := store.Get("expired-session")
if ok {
t.Error("过期的 session 应返回 false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
CreatedAt: time.Now(),
}
store.Set("del-session", session)
store.Delete("del-session")
_, ok := store.Get("del-session")
if ok {
t.Error("删除后 Get 应返回 false")
}
}
func TestSessionStore_Delete_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 删除不存在的 session 不应 panic
store.Delete("nonexistent")
}
func TestSessionStore_Stop(t *testing.T) {
store := NewSessionStore()
store.Stop()
// 多次 Stop 不应 panic
store.Stop()
}
func TestSessionStore_多个Session(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
for i := 0; i < 10; i++ {
session := &OAuthSession{
State: "state-" + string(rune('0'+i)),
CreatedAt: time.Now(),
}
store.Set("session-"+string(rune('0'+i)), session)
}
// 验证都能取到
for i := 0; i < 10; i++ {
_, ok := store.Get("session-" + string(rune('0'+i)))
if !ok {
t.Errorf("session-%d 应存在", i)
}
}
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes_长度正确(t *testing.T) {
sizes := []int{0, 1, 16, 32, 64, 128}
for _, size := range sizes {
b, err := GenerateRandomBytes(size)
if err != nil {
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
}
if len(b) != size {
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
}
}
}
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
b1, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第一次调用失败: %v", err)
}
b2, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第二次调用失败: %v", err)
}
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
if string(b1) == string(b2) {
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState
// ---------------------------------------------------------------------------
func TestGenerateState_返回值格式(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState 失败: %v", err)
}
if state == "" {
t.Error("GenerateState 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(state, "+/=") {
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
}
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
if len(state) != 43 {
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
}
}
func TestGenerateState_唯一性(t *testing.T) {
s1, _ := GenerateState()
s2, _ := GenerateState()
if s1 == s2 {
t.Error("两次 GenerateState 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID
// ---------------------------------------------------------------------------
func TestGenerateSessionID_返回值格式(t *testing.T) {
id, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID 失败: %v", err)
}
if id == "" {
t.Error("GenerateSessionID 返回空字符串")
}
// 16 字节的 hex 编码长度应为 32
if len(id) != 32 {
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
}
// 验证是合法的 hex 字符串
if _, err := hex.DecodeString(id); err != nil {
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
}
}
func TestGenerateSessionID_唯一性(t *testing.T) {
id1, _ := GenerateSessionID()
id2, _ := GenerateSessionID()
if id1 == id2 {
t.Error("两次 GenerateSessionID 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(verifier, "+/=") {
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
}
// 32 字节的 base64url 编码长度应为 43
if len(verifier) != 43 {
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
}
}
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
v1, _ := GenerateCodeVerifier()
v2, _ := GenerateCodeVerifier()
if v1 == v2 {
t.Error("两次 GenerateCodeVerifier 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
challenge := GenerateCodeChallenge(verifier)
// 手动计算预期值
hash := sha256.Sum256([]byte(verifier))
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
if challenge != expected {
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
}
}
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier")
if strings.Contains(challenge, "=") {
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
challenge := GenerateCodeChallenge("another-verifier")
if strings.ContainsAny(challenge, "+/") {
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
c1 := GenerateCodeChallenge("same-verifier")
c2 := GenerateCodeChallenge("same-verifier")
if c1 != c2 {
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
}
}
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
c1 := GenerateCodeChallenge("verifier-1")
c2 := GenerateCodeChallenge("verifier-2")
if c1 == c2 {
t.Error("不同输入应产生不同输出")
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
state := "test-state-123"
codeChallenge := "test-challenge-abc"
authURL := BuildAuthorizationURL(state, codeChallenge)
// 验证以 AuthorizeURL 开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
}
// 解析 URL 并验证参数
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
expectedParams := map[string]string{
"client_id": ClientID,
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"include_granted_scopes": "true",
}
for key, want := range expectedParams {
got := params.Get(key)
if got != want {
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
}
}
}
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
authURL := BuildAuthorizationURL("s", "c")
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
// 应包含 10 个参数
expectedCount := 10
if len(params) != expectedCount {
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
}
}
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
state := "state+with/special=chars"
codeChallenge := "challenge+value"
authURL := BuildAuthorizationURL(state, codeChallenge)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
// 解析后应正确还原特殊字符
if got := parsed.Query().Get("state"); got != state {
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
}
}
// ---------------------------------------------------------------------------
// 常量值验证
// ---------------------------------------------------------------------------
func TestConstants_值正确(t *testing.T) {
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
}
if TokenURL != "https://oauth2.googleapis.com/token" {
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
}
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
}
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID)
}
secret, err := getClientSecret()
if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
if secret != "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" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
}
if URLAvailabilityTTL != 5*time.Minute {
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
}
}
func TestScopes_包含必要范围(t *testing.T) {
expectedScopes := []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
for _, scope := range expectedScopes {
if !strings.Contains(Scopes, scope) {
t.Errorf("Scopes 缺少 %s", scope)
}
}
}
......@@ -206,6 +206,7 @@ type modelInfo struct {
var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
}
......
package antigravity
import (
"crypto/rand"
"encoding/json"
"fmt"
"log"
"strings"
"sync/atomic"
"time"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
......@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return builder.String()
}
// generateRandomID 生成随机 ID
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
var fallbackCounter uint64
// generateRandomID 生成密码学安全的随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12)
for i := range result {
result[i] = chars[i%len(chars)]
}
return string(result)
id := make([]byte, 12)
randBytes := make([]byte, 12)
if _, err := rand.Read(randBytes); err != nil {
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
cnt := atomic.AddUint64(&fallbackCounter, 1)
seed := uint64(time.Now().UnixNano()) ^ cnt
seed ^= uint64(len(err.Error())) << 32
for i := range id {
seed ^= seed << 13
seed ^= seed >> 7
seed ^= seed << 17
id[i] = chars[int(seed)%len(chars)]
}
return string(id)
}
for i, b := range randBytes {
id[i] = chars[int(b)%len(chars)]
}
return string(id)
}
//go:build unit
package antigravity
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
func TestGenerateRandomID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
id := generateRandomID()
require.Len(t, id, 12, "ID 长度应为 12")
_, dup := seen[id]
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
seen[id] = struct{}{}
}
}
func TestFallbackCounter_Increments(t *testing.T) {
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
before := atomic.LoadUint64(&fallbackCounter)
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
}
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
// 验证并发递增的原子性 — 每次递增都应产生唯一值
const goroutines = 50
results := make([]uint64, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
}(i)
}
wg.Wait()
// 所有结果应唯一
seen := make(map[uint64]bool, goroutines)
for _, v := range results {
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
seen[v] = true
}
}
func TestGenerateRandomID_Charset(t *testing.T) {
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet := make(map[byte]struct{}, len(validChars))
for i := 0; i < len(validChars); i++ {
validSet[validChars[i]] = struct{}{}
}
for i := 0; i < 50; i++ {
id := generateRandomID()
for j := 0; j < len(id); j++ {
_, ok := validSet[id[j]]
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
}
}
}
func TestGenerateRandomID_Length(t *testing.T) {
for i := 0; i < 100; i++ {
id := generateRandomID()
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
}
}
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
// 验证并发调用不会产生重复 ID
const goroutines = 100
results := make([]string, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = generateRandomID()
}(i)
}
wg.Wait()
seen := make(map[string]bool, goroutines)
for _, id := range results {
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
seen[id] = true
}
}
func BenchmarkGenerateRandomID(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = generateRandomID()
}
}
......@@ -10,8 +10,14 @@ const (
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
// 这些 token 是客户端特有的,不应透传给上游 API。
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
......@@ -77,6 +83,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
{
ID: "claude-sonnet-4-6",
Type: "model",
DisplayName: "Claude Sonnet 4.6",
CreatedAt: "2026-02-18T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
......
......@@ -8,9 +8,21 @@ const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// RequestID 为服务端生成/透传的请求 ID。
RequestID Key = "ctx_request_id"
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id"
// Model 请求模型标识(用于统一请求链路日志字段)。
Model Key = "ctx_model"
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
Platform Key = "ctx_platform"
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
AccountID Key = "ctx_account_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
......@@ -32,4 +44,15 @@ const (
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry Key = "ctx_single_account_retry"
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
ClaudeCodeVersion Key = "ctx_claude_code_version"
)
......@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
})
}
}
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
md := map[string]string{"k": "v"}
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
code, body := ToHTTP(appErr)
require.Equal(t, http.StatusBadRequest, code)
require.Equal(t, "v", body.Metadata["k"])
md["k"] = "changed"
require.Equal(t, "v", body.Metadata["k"])
appErr.Metadata["k"] = "changed-again"
require.Equal(t, "v", body.Metadata["k"])
}
......@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
body = Status{
Code: appErr.Code,
Reason: appErr.Reason,
Message: appErr.Message,
}
if appErr.Metadata != nil {
body.Metadata = make(map[string]string, len(appErr.Metadata))
for k, v := range appErr.Metadata {
body.Metadata[k] = v
}
}
return int(appErr.Code), body
}
......@@ -21,6 +21,7 @@ func DefaultModels() []Model {
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
}
}
......
......@@ -41,6 +41,9 @@ const (
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment