Unverified Commit 07be258d authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #603 from mt21625457/release

feat : 大幅度的性能优化 和 新增了很多功能
parents dbdb2959 53d55bb9
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/testutil"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 编译期接口断言
var _ service.SoraClient = (*stubSoraClient)(nil)
var _ service.AccountRepository = (*stubAccountRepo)(nil)
var _ service.GroupRepository = (*stubGroupRepo)(nil)
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
type stubSoraClient struct {
imageURLs []string
}
func (s *stubSoraClient) Enabled() bool { return true }
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
return "upload", nil
}
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
type stubAccountRepo struct {
accounts map[int64]*service.Account
}
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if acc, ok := r.accounts[id]; ok {
return acc, nil
}
return nil, service.ErrAccountNotFound
}
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
var result []*service.Account
for _, id := range ids {
if acc, ok := r.accounts[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
_, ok := r.accounts[id]
return ok, nil
}
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return map[string]int64{}, nil
}
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
var result []service.Account
for _, acc := range r.accounts {
for _, platform := range platforms {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
break
}
}
}
return result, nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
type stubGroupRepo struct {
group *service.Group
}
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, nil
}
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
type stubUsageLogRepo struct{}
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return true, nil
}
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, nil
}
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, nil
}
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
RunMode: config.RunModeSimple,
Gateway: config.GatewayConfig{
SoraStreamMode: "force",
MaxAccountSwitches: 1,
Scheduling: config.GatewaySchedulingConfig{
LoadBatchEnabled: false,
},
},
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.test",
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
groupRepo := &stubGroupRepo{group: group}
usageLogRepo := &stubUsageLogRepo{}
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
gatewayService := service.NewGatewayService(
accountRepo,
groupRepo,
usageLogRepo,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,
concurrencyService,
billingService,
nil,
billingCacheService,
nil,
nil,
deferredService,
nil,
testutil.StubSessionLimitCache{},
nil,
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
apiKey := &service.APIKey{
ID: 1,
UserID: 1,
Status: service.StatusActive,
GroupID: &group.ID,
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
handler.ChatCompletions(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
func TestSoraHandler_StreamForcing(t *testing.T) {
// 测试 1:stream=false 时 sjson 强制修改为 true
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
clientStream := gjson.GetBytes(body, "stream").Bool()
require.False(t, clientStream)
newBody, err := sjson.SetBytes(body, "stream", true)
require.NoError(t, err)
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
// 测试 2:stream=true 时不修改
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
require.True(t, gjson.GetBytes(body2, "stream").Bool())
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
require.False(t, gjson.GetBytes(body3, "stream").Bool())
}
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
func TestSoraHandler_ValidationExtraction(t *testing.T) {
// model 缺失
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
// model 为数字 → 类型不是 gjson.String,应被拒绝
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
modelResult1b := gjson.GetBytes(body1b, "model")
require.True(t, modelResult1b.Exists())
require.NotEqual(t, gjson.String, modelResult1b.Type)
// messages 缺失
body2 := []byte(`{"model":"sora"}`)
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
// messages 不是 JSON 数组(字符串)
body3 := []byte(`{"model":"sora","messages":"not array"}`)
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
// messages 是对象而非数组 → IsArray 返回 false
body4 := []byte(`{"model":"sora","messages":{}}`)
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
body5 := []byte(`{"model":"sora","messages":[]}`)
msgsResult := gjson.GetBytes(body5, "messages")
require.True(t, msgsResult.IsArray())
require.Equal(t, 0, len(msgsResult.Array()))
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
}
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
gin.SetMode(gin.TestMode)
// 从 body 提取 prompt_cache_key
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
hash := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash)
// 无 prompt_cache_key 且无 header → 空 hash
body2 := []byte(`{"model":"sora"}`)
hash2 := generateOpenAISessionHash(c, body2)
require.Empty(t, hash2)
// header 优先于 body
c.Request.Header.Set("session_id", "from-header")
hash3 := generateOpenAISessionHash(c, body)
require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
}
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
headers := http.Header{}
headers.Set("cf-mitigated", "challenge")
headers.Set("content-type", "text/html")
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
require.Equal(t, "9cff2d62d83bb98d", rayID)
require.Equal(t, "challenge", mitigated)
require.Equal(t, "text/html", contentType)
}
...@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { ...@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
t.Helper()
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
WorkerCount: 1,
QueueSize: 8,
TaskTimeout: time.Second,
OverflowPolicy: "drop",
OverflowSamplePercent: 0,
AutoScaleEnabled: false,
})
t.Cleanup(pool.Stop)
return pool
}
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &GatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
...@@ -53,8 +53,8 @@ func ProvideAdminHandlers( ...@@ -53,8 +53,8 @@ func ProvideAdminHandlers(
} }
// ProvideSystemHandler creates admin.SystemHandler with UpdateService // ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService) return admin.NewSystemHandler(updateService, lockService)
} }
// ProvideSettingHandler creates SettingHandler with version from BuildInfo // ProvideSettingHandler creates SettingHandler with version from BuildInfo
...@@ -74,8 +74,11 @@ func ProvideHandlers( ...@@ -74,8 +74,11 @@ func ProvideHandlers(
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler, totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
...@@ -88,6 +91,7 @@ func ProvideHandlers( ...@@ -88,6 +91,7 @@ func ProvideHandlers(
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler, Totp: totpHandler,
} }
...@@ -105,6 +109,7 @@ var ProviderSet = wire.NewSet( ...@@ -105,6 +109,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler, NewAnnouncementHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
NewTotpHandler, NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
......
...@@ -21,11 +21,18 @@ var ( ...@@ -21,11 +21,18 @@ var (
// - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户)
// - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户)
endpointPrefix = getEnv("ENDPOINT_PREFIX", "") endpointPrefix = getEnv("ENDPOINT_PREFIX", "")
claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3"
geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f"
testInterval = 1 * time.Second // 测试间隔,防止限流 testInterval = 1 * time.Second // 测试间隔,防止限流
) )
const (
// 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。
// 例如:
// export CLAUDE_API_KEY="sk-..."
// export GEMINI_API_KEY="sk-..."
claudeAPIKeyEnv = "CLAUDE_API_KEY"
geminiAPIKeyEnv = "GEMINI_API_KEY"
)
func getEnv(key, defaultVal string) string { func getEnv(key, defaultVal string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v
...@@ -65,16 +72,45 @@ func TestMain(m *testing.M) { ...@@ -65,16 +72,45 @@ func TestMain(m *testing.M) {
if endpointPrefix != "" { if endpointPrefix != "" {
mode = "Antigravity 模式" mode = "Antigravity 模式"
} }
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != ""
geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != ""
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n",
baseURL,
endpointPrefix,
mode,
claudeAPIKeyEnv,
claudeKeySet,
geminiAPIKeyEnv,
geminiKeySet,
)
os.Exit(m.Run()) os.Exit(m.Run())
} }
func requireClaudeAPIKey(t *testing.T) string {
t.Helper()
key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
if key == "" {
t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv)
}
return key
}
func requireGeminiAPIKey(t *testing.T) string {
t.Helper()
key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
if key == "" {
t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv)
}
return key
}
// TestClaudeModelsList 测试 GET /v1/models // TestClaudeModelsList 测试 GET /v1/models
func TestClaudeModelsList(t *testing.T) { func TestClaudeModelsList(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
url := baseURL + endpointPrefix + "/v1/models" url := baseURL + endpointPrefix + "/v1/models"
req, _ := http.NewRequest("GET", url, nil) req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+claudeAPIKey) req.Header.Set("Authorization", "Bearer "+claudeKey)
client := &http.Client{Timeout: 30 * time.Second} client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
...@@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { ...@@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) {
// TestGeminiModelsList 测试 GET /v1beta/models // TestGeminiModelsList 测试 GET /v1beta/models
func TestGeminiModelsList(t *testing.T) { func TestGeminiModelsList(t *testing.T) {
geminiKey := requireGeminiAPIKey(t)
url := baseURL + endpointPrefix + "/v1beta/models" url := baseURL + endpointPrefix + "/v1beta/models"
req, _ := http.NewRequest("GET", url, nil) req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+geminiAPIKey) req.Header.Set("Authorization", "Bearer "+geminiKey)
client := &http.Client{Timeout: 30 * time.Second} client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
...@@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { ...@@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) {
// TestClaudeMessages 测试 Claude /v1/messages 接口 // TestClaudeMessages 测试 Claude /v1/messages 接口
func TestClaudeMessages(t *testing.T) { func TestClaudeMessages(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
for i, model := range claudeModels { for i, model := range claudeModels {
if i > 0 { if i > 0 {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_非流式", func(t *testing.T) { t.Run(model+"_非流式", func(t *testing.T) {
testClaudeMessage(t, model, false) testClaudeMessage(t, claudeKey, model, false)
}) })
time.Sleep(testInterval) time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) { t.Run(model+"_流式", func(t *testing.T) {
testClaudeMessage(t, model, true) testClaudeMessage(t, claudeKey, model, true)
}) })
} }
} }
func testClaudeMessage(t *testing.T, model string, stream bool) { func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) {
url := baseURL + endpointPrefix + "/v1/messages" url := baseURL + endpointPrefix + "/v1/messages"
payload := map[string]any{ payload := map[string]any{
...@@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { ...@@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey) req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second} client := &http.Client{Timeout: 60 * time.Second}
...@@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { ...@@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
// TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口
func TestGeminiGenerateContent(t *testing.T) { func TestGeminiGenerateContent(t *testing.T) {
geminiKey := requireGeminiAPIKey(t)
for i, model := range geminiModels { for i, model := range geminiModels {
if i > 0 { if i > 0 {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_非流式", func(t *testing.T) { t.Run(model+"_非流式", func(t *testing.T) {
testGeminiGenerate(t, model, false) testGeminiGenerate(t, geminiKey, model, false)
}) })
time.Sleep(testInterval) time.Sleep(testInterval)
t.Run(model+"_流式", func(t *testing.T) { t.Run(model+"_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true) testGeminiGenerate(t, geminiKey, model, true)
}) })
} }
} }
func testGeminiGenerate(t *testing.T, model string, stream bool) { func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) {
action := "generateContent" action := "generateContent"
if stream { if stream {
action = "streamGenerateContent" action = "streamGenerateContent"
...@@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { ...@@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+geminiAPIKey) req.Header.Set("Authorization", "Bearer "+geminiKey)
client := &http.Client{Timeout: 60 * time.Second} client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
...@@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { ...@@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
// TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求
// 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段
func TestClaudeMessagesWithComplexTools(t *testing.T) { func TestClaudeMessagesWithComplexTools(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
// 测试模型列表(只测试几个代表性模型) // 测试模型列表(只测试几个代表性模型)
models := []string{ models := []string{
"claude-opus-4-5-20251101", // Claude 模型 "claude-opus-4-5-20251101", // Claude 模型
...@@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { ...@@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_复杂工具", func(t *testing.T) { t.Run(model+"_复杂工具", func(t *testing.T) {
testClaudeMessageWithTools(t, model) testClaudeMessageWithTools(t, claudeKey, model)
}) })
} }
} }
func testClaudeMessageWithTools(t *testing.T, model string) { func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages" url := baseURL + endpointPrefix + "/v1/messages"
// 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具)
...@@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { ...@@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey) req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second} client := &http.Client{Timeout: 60 * time.Second}
...@@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { ...@@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
// 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时,
// 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误
func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
models := []string{ models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash "claude-haiku-4-5-20251001", // gemini-3-flash
} }
...@@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { ...@@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_thinking模式工具调用", func(t *testing.T) { t.Run(model+"_thinking模式工具调用", func(t *testing.T) {
testClaudeThinkingWithToolHistory(t, model) testClaudeThinkingWithToolHistory(t, claudeKey, model)
}) })
} }
} }
func testClaudeThinkingWithToolHistory(t *testing.T, model string) { func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages" url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话
...@@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { ...@@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey) req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second} client := &http.Client{Timeout: 60 * time.Second}
...@@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { ...@@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" { if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行") t.Skip("仅在 Antigravity 模式下运行")
} }
claudeKey := requireClaudeAPIKey(t)
// 测试通过 Claude 端点调用 Gemini 模型 // 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{ geminiViaClaude := []string{
...@@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { ...@@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_通过Claude端点", func(t *testing.T) { t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false) testClaudeMessage(t, claudeKey, model, false)
}) })
time.Sleep(testInterval) time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true) testClaudeMessage(t, claudeKey, model, true)
}) })
} }
} }
...@@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { ...@@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block // 验证:Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) { func TestClaudeMessagesWithNoSignature(t *testing.T) {
claudeKey := requireClaudeAPIKey(t)
models := []string{ models := []string{
"claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature
} }
...@@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { ...@@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_无signature", func(t *testing.T) { t.Run(model+"_无signature", func(t *testing.T) {
testClaudeWithNoSignature(t, model) testClaudeWithNoSignature(t, claudeKey, model)
}) })
} }
} }
func testClaudeWithNoSignature(t *testing.T, model string) { func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) {
url := baseURL + endpointPrefix + "/v1/messages" url := baseURL + endpointPrefix + "/v1/messages"
// 模拟历史对话包含 thinking block 但没有 signature // 模拟历史对话包含 thinking block 但没有 signature
...@@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { ...@@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+claudeAPIKey) req.Header.Set("Authorization", "Bearer "+claudeKey)
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: 60 * time.Second} client := &http.Client{Timeout: 60 * time.Second}
...@@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { ...@@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" { if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行") t.Skip("仅在 Antigravity 模式下运行")
} }
geminiKey := requireGeminiAPIKey(t)
// 测试通过 Gemini 端点调用 Claude 模型 // 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{ claudeViaGemini := []string{
...@@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { ...@@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
time.Sleep(testInterval) time.Sleep(testInterval)
} }
t.Run(model+"_通过Gemini端点", func(t *testing.T) { t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false) testGeminiGenerate(t, geminiKey, model, false)
}) })
time.Sleep(testInterval) time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true) testGeminiGenerate(t, geminiKey, model, true)
}) })
} }
} }
//go:build e2e
package integration
import (
"os"
"strings"
"testing"
)
// =============================================================================
// E2E Mock 模式支持
// =============================================================================
// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。
// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。
// isMockMode 检查是否启用 Mock 模式
func isMockMode() bool {
return strings.EqualFold(os.Getenv("E2E_MOCK"), "true")
}
// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试
func skipIfNoRealAPI(t *testing.T) {
t.Helper()
if isMockMode() {
return // Mock 模式下不跳过
}
claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
if claudeKey == "" && geminiKey == "" {
t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试")
}
}
// =============================================================================
// API Key 脱敏(Task 6.10)
// =============================================================================
// safeLogKey 安全地记录 API Key(仅显示前 8 位)
func safeLogKey(t *testing.T, prefix string, key string) {
t.Helper()
key = strings.TrimSpace(key)
if len(key) <= 8 {
t.Logf("%s: ***(长度: %d)", prefix, len(key))
return
}
t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key))
}
//go:build e2e
package integration
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
)
// E2E 用户流程测试
// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量
var (
testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local"
testUserPassword = "E2eTest@12345"
testUserName = "e2e-test-user"
)
// TestUserRegistrationAndLogin 测试用户注册和登录流程
func TestUserRegistrationAndLogin(t *testing.T) {
// 步骤 1: 注册新用户
t.Run("注册新用户", func(t *testing.T) {
payload := map[string]string{
"email": testUserEmail,
"password": testUserPassword,
"username": testUserName,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
if err != nil {
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭)
switch resp.StatusCode {
case 200:
t.Logf("✅ 用户注册成功: %s", testUserEmail)
case 400:
t.Logf("⚠️ 用户可能已存在: %s", string(respBody))
case 403:
t.Skipf("注册功能已关闭: %s", string(respBody))
default:
t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody))
}
})
// 步骤 2: 登录获取 JWT
var accessToken string
t.Run("用户登录获取JWT", func(t *testing.T) {
payload := map[string]string{
"email": testUserEmail,
"password": testUserPassword,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
if err != nil {
t.Fatalf("登录请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody))
return
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析登录响应失败: %v", err)
}
// 尝试从标准响应格式获取 token
if token, ok := result["access_token"].(string); ok && token != "" {
accessToken = token
} else if data, ok := result["data"].(map[string]any); ok {
if token, ok := data["access_token"].(string); ok {
accessToken = token
}
}
if accessToken == "" {
t.Skipf("未获取到 access_token,响应: %s", string(respBody))
return
}
// 验证 token 不为空且格式基本正确
if len(accessToken) < 10 {
t.Fatalf("access_token 格式异常: %s", accessToken)
}
t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken))
})
if accessToken == "" {
t.Skip("未获取到 JWT,跳过后续测试")
return
}
// 步骤 3: 使用 JWT 获取当前用户信息
t.Run("获取当前用户信息", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
}
t.Logf("✅ 成功获取用户信息")
})
}
// TestAPIKeyLifecycle 测试 API Key 的创建和使用
func TestAPIKeyLifecycle(t *testing.T) {
// 先登录获取 JWT
accessToken := loginTestUser(t)
if accessToken == "" {
t.Skip("无法登录,跳过 API Key 生命周期测试")
return
}
var apiKey string
// 步骤 1: 创建 API Key
t.Run("创建API_Key", func(t *testing.T) {
payload := map[string]string{
"name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()),
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
if err != nil {
t.Fatalf("创建 API Key 请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody))
return
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
// 从响应中提取 key
if key, ok := result["key"].(string); ok {
apiKey = key
} else if data, ok := result["data"].(map[string]any); ok {
if key, ok := data["key"].(string); ok {
apiKey = key
}
}
if apiKey == "" {
t.Skipf("未获取到 API Key,响应: %s", string(respBody))
return
}
// 验证 API Key 脱敏日志(只显示前 8 位)
masked := apiKey
if len(masked) > 8 {
masked = masked[:8] + "..."
}
t.Logf("✅ API Key 创建成功: %s", masked)
})
if apiKey == "" {
t.Skip("未创建 API Key,跳过后续测试")
return
}
// 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用)
t.Run("使用API_Key调用网关", func(t *testing.T) {
// 尝试调用 models 列表(最轻量的 API 调用)
resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey)
if err != nil {
t.Fatalf("网关请求失败: %v", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 可能返回 200(成功)或 402(余额不足)或 403(无可用账户)
switch {
case resp.StatusCode == 200:
t.Logf("✅ API Key 网关调用成功")
case resp.StatusCode == 402:
t.Logf("⚠️ 余额不足,但 API Key 认证通过")
case resp.StatusCode == 403:
t.Logf("⚠️ 无可用账户,但 API Key 认证通过")
default:
t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody))
}
})
// 步骤 3: 查询用量记录
t.Run("查询用量记录", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
if err != nil {
t.Fatalf("用量查询请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body))
return
}
t.Logf("✅ 用量查询成功")
})
}
// =============================================================================
// 辅助函数
// =============================================================================
func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) {
t.Helper()
url := baseURL + path
var bodyReader io.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequest(method, url, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
client := &http.Client{Timeout: 30 * time.Second}
return client.Do(req)
}
func loginTestUser(t *testing.T) string {
t.Helper()
// 先尝试用管理员账户登录
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local")
adminPassword := getEnv("ADMIN_PASSWORD", "")
if adminPassword == "" {
// 尝试用测试用户
adminEmail = testUserEmail
adminPassword = testUserPassword
}
payload := map[string]string{
"email": adminEmail,
"password": adminPassword,
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return ""
}
respBody, _ := io.ReadAll(resp.Body)
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if token, ok := result["access_token"].(string); ok {
return token
}
if data, ok := result["data"].(map[string]any); ok {
if token, ok := data["access_token"].(string); ok {
return token
}
}
return ""
}
// redactAPIKey API Key 脱敏,只显示前 8 位
func redactAPIKey(key string) string {
key = strings.TrimSpace(key)
if len(key) <= 8 {
return "***"
}
return key[:8] + "..."
}
...@@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { ...@@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) {
require.Equal(t, http.StatusTooManyRequests, recorder.Code) require.Equal(t, http.StatusTooManyRequests, recorder.Code)
} }
func TestRateLimiterDifferentIPsIndependent(t *testing.T) {
gin.SetMode(gin.TestMode)
callCounts := make(map[string]int64)
originalRun := rateLimitRun
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
callCounts[key]++
return callCounts[key], false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun
})
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
router := gin.New()
router.Use(limiter.Limit("api", 1, time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
// 第一个 IP 的请求应通过
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "10.0.0.1:1234"
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过")
// 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响)
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "10.0.0.2:5678"
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过")
// 第一个 IP 的第二次请求应被限流
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "10.0.0.1:1234"
rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3)
require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流")
}
func TestRateLimiterSuccessAndLimit(t *testing.T) { func TestRateLimiterSuccessAndLimit(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
...@@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool { ...@@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{} params := url.Values{}
params.Set("client_id", ClientID) params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret) params.Set("client_secret", clientSecret)
params.Set("code", code) params.Set("code", code)
params.Set("redirect_uri", RedirectURI) params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code") params.Set("grant_type", "authorization_code")
...@@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* ...@@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
// RefreshToken 刷新 access_token // RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{} params := url.Values{}
params.Set("client_id", ClientID) params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret) params.Set("client_secret", clientSecret)
params.Set("refresh_token", refreshToken) params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token") params.Set("grant_type", "refresh_token")
......
//go:build unit
package antigravity package antigravity
import ( import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing" "testing"
"time"
) )
// ---------------------------------------------------------------------------
// NewAPIRequestWithURL
// ---------------------------------------------------------------------------
func TestNewAPIRequestWithURL_普通请求(t *testing.T) {
ctx := context.Background()
baseURL := "https://example.com"
action := "generateContent"
token := "test-token"
body := []byte(`{"prompt":"hello"}`)
req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
// 验证 URL 不含 ?alt=sse
expectedURL := "https://example.com/v1internal:generateContent"
if req.URL.String() != expectedURL {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL)
}
// 验证请求方法
if req.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", req.Method)
}
// 验证 Headers
if ct := req.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ua := req.Header.Get("User-Agent"); ua != UserAgent {
t.Errorf("User-Agent 不匹配: got %s, want %s", ua, UserAgent)
}
}
func TestNewAPIRequestWithURL_流式请求(t *testing.T) {
ctx := context.Background()
baseURL := "https://example.com"
action := "streamGenerateContent"
token := "tok"
body := []byte(`{}`)
req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse"
if req.URL.String() != expectedURL {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL)
}
}
func TestNewAPIRequestWithURL_空Body(t *testing.T) {
ctx := context.Background()
req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
if req.Body == nil {
t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)")
}
}
// ---------------------------------------------------------------------------
// NewAPIRequest
// ---------------------------------------------------------------------------
func TestNewAPIRequest_使用默认URL(t *testing.T) {
ctx := context.Background()
req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
expected := BaseURL + "/v1internal:generateContent"
if req.URL.String() != expected {
t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected)
}
}
// ---------------------------------------------------------------------------
// TierInfo.UnmarshalJSON
// ---------------------------------------------------------------------------
func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) {
data := []byte(`"free-tier"`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if tier.ID != "free-tier" {
t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID)
}
if tier.Name != "" {
t.Errorf("Name 应为空: got %s", tier.Name)
}
}
func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) {
data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if tier.ID != "g1-pro-tier" {
t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID)
}
if tier.Name != "Pro" {
t.Errorf("Name 不匹配: got %s, want Pro", tier.Name)
}
if tier.Description != "Pro plan" {
t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description)
}
}
func TestTierInfo_UnmarshalJSON_null(t *testing.T) {
data := []byte(`null`)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化 null 失败: %v", err)
}
if tier.ID != "" {
t.Errorf("null 场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) {
data := []byte(``)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化空数据失败: %v", err)
}
if tier.ID != "" {
t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) {
data := []byte(` null `)
var tier TierInfo
if err := tier.UnmarshalJSON(data); err != nil {
t.Fatalf("反序列化空格 null 失败: %v", err)
}
if tier.ID != "" {
t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID)
}
}
func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
// 模拟 LoadCodeAssistResponse 中的嵌套反序列化
jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}`
var resp LoadCodeAssistResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化嵌套结构失败: %v", err)
}
if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" {
t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier)
}
if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" {
t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier)
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssistResponse.GetTier
// ---------------------------------------------------------------------------
func TestGetTier_PaidTier优先(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: "g1-pro-tier"},
}
if got := resp.GetTier(); got != "g1-pro-tier" {
t.Errorf("应返回 paidTier: got %s", got)
}
}
func TestGetTier_回退到CurrentTier(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
}
if got := resp.GetTier(); got != "free-tier" {
t.Errorf("应返回 currentTier: got %s", got)
}
}
func TestGetTier_PaidTier为空ID(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: ""},
}
// paidTier.ID 为空时应回退到 currentTier
if got := resp.GetTier(); got != "free-tier" {
t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got)
}
}
func TestGetTier_两者都为nil(t *testing.T) {
resp := &LoadCodeAssistResponse{}
if got := resp.GetTier(); got != "" {
t.Errorf("两者都为 nil 时应返回空字符串: got %s", got)
}
}
// ---------------------------------------------------------------------------
// NewClient
// ---------------------------------------------------------------------------
func TestNewClient_无代理(t *testing.T) {
client := NewClient("")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
if client.httpClient == nil {
t.Fatal("httpClient 为 nil")
}
if client.httpClient.Timeout != 30*time.Second {
t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout)
}
// 无代理时 Transport 应为 nil(使用默认)
if client.httpClient.Transport != nil {
t.Error("无代理时 Transport 应为 nil")
}
}
func TestNewClient_有代理(t *testing.T) {
client := NewClient("http://proxy.example.com:8080")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
if client.httpClient.Transport == nil {
t.Fatal("有代理时 Transport 不应为 nil")
}
}
func TestNewClient_空格代理(t *testing.T) {
client := NewClient(" ")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
// 空格代理应等同于无代理
if client.httpClient.Transport != nil {
t.Error("空格代理 Transport 应为 nil")
}
}
func TestNewClient_无效代理URL(t *testing.T) {
// 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容),
// 但 ://invalid 会导致解析错误
client := NewClient("://invalid")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
// 无效 URL 解析失败时,Transport 应保持 nil
if client.httpClient.Transport != nil {
t.Error("无效代理 URL 时 Transport 应为 nil")
}
}
// ---------------------------------------------------------------------------
// isConnectionError
// ---------------------------------------------------------------------------
func TestIsConnectionError_nil(t *testing.T) {
if isConnectionError(nil) {
t.Error("nil 错误不应判定为连接错误")
}
}
func TestIsConnectionError_超时错误(t *testing.T) {
// 使用 net.OpError 包装超时
err := &net.OpError{
Op: "dial",
Net: "tcp",
Err: &timeoutError{},
}
if !isConnectionError(err) {
t.Error("超时错误应判定为连接错误")
}
}
// timeoutError 实现 net.Error 接口用于测试
type timeoutError struct{}
func (e *timeoutError) Error() string { return "timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestIsConnectionError_netOpError(t *testing.T) {
err := &net.OpError{
Op: "dial",
Net: "tcp",
Err: fmt.Errorf("connection refused"),
}
if !isConnectionError(err) {
t.Error("net.OpError 应判定为连接错误")
}
}
func TestIsConnectionError_urlError(t *testing.T) {
err := &url.Error{
Op: "Get",
URL: "https://example.com",
Err: fmt.Errorf("some error"),
}
if !isConnectionError(err) {
t.Error("url.Error 应判定为连接错误")
}
}
func TestIsConnectionError_普通错误(t *testing.T) {
err := fmt.Errorf("some random error")
if isConnectionError(err) {
t.Error("普通错误不应判定为连接错误")
}
}
func TestIsConnectionError_包装的netOpError(t *testing.T) {
inner := &net.OpError{
Op: "dial",
Net: "tcp",
Err: fmt.Errorf("connection refused"),
}
err := fmt.Errorf("wrapping: %w", inner)
if !isConnectionError(err) {
t.Error("被包装的 net.OpError 应判定为连接错误")
}
}
// ---------------------------------------------------------------------------
// shouldFallbackToNextURL
// ---------------------------------------------------------------------------
func TestShouldFallbackToNextURL_连接错误(t *testing.T) {
err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")}
if !shouldFallbackToNextURL(err, 0) {
t.Error("连接错误应触发 URL 降级")
}
}
func TestShouldFallbackToNextURL_状态码(t *testing.T) {
tests := []struct {
name string
statusCode int
want bool
}{
{"429 Too Many Requests", http.StatusTooManyRequests, true},
{"408 Request Timeout", http.StatusRequestTimeout, true},
{"404 Not Found", http.StatusNotFound, true},
{"500 Internal Server Error", http.StatusInternalServerError, true},
{"502 Bad Gateway", http.StatusBadGateway, true},
{"503 Service Unavailable", http.StatusServiceUnavailable, true},
{"200 OK", http.StatusOK, false},
{"201 Created", http.StatusCreated, false},
{"400 Bad Request", http.StatusBadRequest, false},
{"401 Unauthorized", http.StatusUnauthorized, false},
{"403 Forbidden", http.StatusForbidden, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldFallbackToNextURL(nil, tt.statusCode)
if got != tt.want {
t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want)
}
})
}
}
func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
if shouldFallbackToNextURL(nil, http.StatusOK) {
t.Error("无错误且 200 不应触发 URL 降级")
}
}
// ---------------------------------------------------------------------------
// Client.ExchangeCode (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_成功(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求方法
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
// 验证 Content-Type
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
// 验证请求体参数
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
if r.FormValue("code") != "auth-code" {
t.Errorf("code 不匹配: got %s", r.FormValue("code"))
}
if r.FormValue("code_verifier") != "verifier123" {
t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier"))
}
if r.FormValue("grant_type") != "authorization_code" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "access-tok",
ExpiresIn: 3600,
TokenType: "Bearer",
RefreshToken: "refresh-tok",
})
}))
defer server.Close()
// 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过)
// 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为
// 这里通过构造一个直接调用 mock server 的测试
client := &Client{httpClient: server.Client()}
// 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL
// 需要使用 httptest 的 Transport 重定向
originalTokenURL := TokenURL
// 我们改为直接构造请求来测试逻辑
_ = originalTokenURL
_ = client
// 改用直接构造请求测试 mock server 响应
ctx := context.Background()
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", "test-secret")
params.Set("code", "auth-code")
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
params.Set("code_verifier", "verifier123")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode()))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
t.Fatalf("解码失败: %v", err)
}
if tokenResp.AccessToken != "access-tok" {
t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken)
}
if tokenResp.RefreshToken != "refresh-tok" {
t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken)
}
}
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
client := NewClient("")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
}
}
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
}))
defer server.Close()
// 直接测试 mock server 的错误响应
resp, err := server.Client().Get(server.URL)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode)
}
}
// ---------------------------------------------------------------------------
// Client.RefreshToken (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_MockServer(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("grant_type") != "refresh_token" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("refresh_token") != "old-refresh-tok" {
t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-tok",
ExpiresIn: 3600,
TokenType: "Bearer",
})
}))
defer server.Close()
ctx := context.Background()
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", "test-secret")
params.Set("refresh_token", "old-refresh-tok")
params.Set("grant_type", "refresh_token")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode()))
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
t.Fatalf("解码失败: %v", err)
}
if tokenResp.AccessToken != "new-access-tok" {
t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken)
}
}
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
client := NewClient("")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.GetUserInfo (使用 httptest)
// ---------------------------------------------------------------------------
func TestClient_GetUserInfo_成功(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("请求方法不匹配: got %s", r.Method)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-access-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(UserInfo{
Email: "user@example.com",
Name: "Test User",
GivenName: "Test",
FamilyName: "User",
Picture: "https://example.com/photo.jpg",
})
}))
defer server.Close()
// 直接通过 mock server 测试 GetUserInfo 的行为逻辑
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("创建请求失败: %v", err)
}
req.Header.Set("Authorization", "Bearer test-access-token")
resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("状态码不匹配: got %d", resp.StatusCode)
}
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
t.Fatalf("解码失败: %v", err)
}
if userInfo.Email != "user@example.com" {
t.Errorf("Email 不匹配: got %s", userInfo.Email)
}
if userInfo.Name != "Test User" {
t.Errorf("Name 不匹配: got %s", userInfo.Name)
}
}
func TestClient_GetUserInfo_服务器返回错误(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_token"}`))
}))
defer server.Close()
resp, err := server.Client().Get(server.URL)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode)
}
}
// ---------------------------------------------------------------------------
// TokenResponse / UserInfo JSON 序列化
// ---------------------------------------------------------------------------
func TestTokenResponse_JSON序列化(t *testing.T) {
jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}`
var resp TokenResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if resp.AccessToken != "at" {
t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken)
}
if resp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn)
}
if resp.RefreshToken != "rt" {
t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken)
}
}
func TestUserInfo_JSON序列化(t *testing.T) {
jsonData := `{"email":"a@b.com","name":"Alice"}`
var info UserInfo
if err := json.Unmarshal([]byte(jsonData), &info); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if info.Email != "a@b.com" {
t.Errorf("Email 不匹配: got %s", info.Email)
}
if info.Name != "Alice" {
t.Errorf("Name 不匹配: got %s", info.Name)
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssistResponse JSON 序列化
// ---------------------------------------------------------------------------
func TestLoadCodeAssistResponse_完整JSON(t *testing.T) {
jsonData := `{
"cloudaicompanionProject": "proj-123",
"currentTier": "free-tier",
"paidTier": {"id": "g1-pro-tier", "name": "Pro"},
"ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}]
}`
var resp LoadCodeAssistResponse
if err := json.Unmarshal([]byte(jsonData), &resp); err != nil {
t.Fatalf("反序列化失败: %v", err)
}
if resp.CloudAICompanionProject != "proj-123" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if resp.GetTier() != "g1-pro-tier" {
t.Errorf("GetTier 不匹配: got %s", resp.GetTier())
}
if len(resp.IneligibleTiers) != 1 {
t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers))
}
if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" {
t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode)
}
}
// ===========================================================================
// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求
// ===========================================================================
// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server
type redirectRoundTripper struct {
// 原始 URL 前缀 -> 替换目标 URL 的映射
redirects map[string]string
transport http.RoundTripper
}
func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
originalURL := req.URL.String()
for prefix, target := range rt.redirects {
if strings.HasPrefix(originalURL, prefix) {
newURL := target + strings.TrimPrefix(originalURL, prefix)
parsed, err := url.Parse(newURL)
if err != nil {
return nil, err
}
req.URL = parsed
break
}
}
if rt.transport == nil {
return http.DefaultTransport.RoundTrip(req)
}
return rt.transport.RoundTrip(req)
}
// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server
func newTestClientWithRedirect(redirects map[string]string) *Client {
return &Client{
httpClient: &http.Client{
Timeout: 10 * time.Second,
Transport: &redirectRoundTripper{
redirects: redirects,
},
},
}
}
// ---------------------------------------------------------------------------
// Client.ExchangeCode - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
if r.FormValue("code") != "test-auth-code" {
t.Errorf("code 不匹配: got %s", r.FormValue("code"))
}
if r.FormValue("code_verifier") != "test-verifier" {
t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier"))
}
if r.FormValue("grant_type") != "authorization_code" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("redirect_uri") != RedirectURI {
t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
TokenType: "Bearer",
Scope: "openid email",
RefreshToken: "new-refresh-token",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier")
if err != nil {
t.Fatalf("ExchangeCode 失败: %v", err)
}
if tokenResp.AccessToken != "new-access-token" {
t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken)
}
if tokenResp.RefreshToken != "new-refresh-token" {
t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken)
}
if tokenResp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn)
}
if tokenResp.TokenType != "Bearer" {
t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType)
}
if tokenResp.Scope != "openid email" {
t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope)
}
}
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.ExchangeCode(context.Background(), "expired-code", "verifier")
if err == nil {
t.Fatal("服务器返回 400 时应返回错误")
}
if !strings.Contains(err.Error(), "token 交换失败") {
t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("错误信息应包含状态码 400: got %s", err.Error())
}
}
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{invalid json`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "token 解析失败") {
t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error())
}
}
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) // 模拟慢响应
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
_, err := client.ExchangeCode(ctx, "code", "verifier")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.RefreshToken - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("解析表单失败: %v", err)
}
if r.FormValue("grant_type") != "refresh_token" {
t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type"))
}
if r.FormValue("refresh_token") != "my-refresh-token" {
t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token"))
}
if r.FormValue("client_id") != ClientID {
t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id"))
}
if r.FormValue("client_secret") != "test-secret" {
t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret"))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "refreshed-access-token",
ExpiresIn: 3600,
TokenType: "Bearer",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token")
if err != nil {
t.Fatalf("RefreshToken 失败: %v", err)
}
if tokenResp.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken)
}
if tokenResp.ExpiresIn != 3600 {
t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn)
}
}
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.RefreshToken(context.Background(), "revoked-token")
if err == nil {
t.Fatal("服务器返回 401 时应返回错误")
}
if !strings.Contains(err.Error(), "token 刷新失败") {
t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error())
}
}
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`not-json`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "token 解析失败") {
t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error())
}
}
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
TokenURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := client.RefreshToken(ctx, "refresh-tok")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.GetUserInfo - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_GetUserInfo_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("请求方法不匹配: got %s, want GET", r.Method)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer user-access-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(UserInfo{
Email: "test@example.com",
Name: "Test User",
GivenName: "Test",
FamilyName: "User",
Picture: "https://example.com/avatar.jpg",
})
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
userInfo, err := client.GetUserInfo(context.Background(), "user-access-token")
if err != nil {
t.Fatalf("GetUserInfo 失败: %v", err)
}
if userInfo.Email != "test@example.com" {
t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email)
}
if userInfo.Name != "Test User" {
t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name)
}
if userInfo.GivenName != "Test" {
t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName)
}
if userInfo.FamilyName != "User" {
t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName)
}
if userInfo.Picture != "https://example.com/avatar.jpg" {
t.Errorf("Picture 不匹配: got %s", userInfo.Picture)
}
}
func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_token"}`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
_, err := client.GetUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 401 时应返回错误")
}
if !strings.Contains(err.Error(), "获取用户信息失败") {
t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "401") {
t.Errorf("错误信息应包含状态码 401: got %s", err.Error())
}
}
func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{broken`))
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
_, err := client.GetUserInfo(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "用户信息解析失败") {
t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error())
}
}
func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := newTestClientWithRedirect(map[string]string{
UserInfoURL: server.URL,
})
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := client.GetUserInfo(ctx, "token")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.LoadCodeAssist - 真正调用方法的测试
// ---------------------------------------------------------------------------
// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复
func withMockBaseURLs(t *testing.T, urls []string) {
t.Helper()
origBaseURLs := BaseURLs
origBaseURL := BaseURL
BaseURLs = urls
if len(urls) > 0 {
BaseURL = urls[0]
}
t.Cleanup(func() {
BaseURLs = origBaseURLs
BaseURL = origBaseURL
})
}
func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") {
t.Errorf("URL 路径不匹配: got %s", r.URL.Path)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if ua := r.Header.Get("User-Agent"); ua != UserAgent {
t.Errorf("User-Agent 不匹配: got %s", ua)
}
// 验证请求体
var reqBody LoadCodeAssistRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
t.Fatalf("解析请求体失败: %v", err)
}
if reqBody.Metadata.IDEType != "ANTIGRAVITY" {
t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"cloudaicompanionProject": "test-project-123",
"currentTier": {"id": "free-tier", "name": "Free"},
"paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"}
}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token")
if err != nil {
t.Fatalf("LoadCodeAssist 失败: %v", err)
}
if resp.CloudAICompanionProject != "test-project-123" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if resp.GetTier() != "g1-pro-tier" {
t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier())
}
if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" {
t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier)
}
if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" {
t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier)
}
// 验证原始 JSON map
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
if rawResp["cloudaicompanionProject"] != "test-project-123" {
t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"])
}
}
func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
}
if !strings.Contains(err.Error(), "loadCodeAssist 失败") {
t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error())
}
if !strings.Contains(err.Error(), "403") {
t.Errorf("错误信息应包含状态码 403: got %s", err.Error())
}
}
func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{not valid json!!!`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "响应解析失败") {
t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error())
}
}
func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) {
// 第一个 server 返回 500,第二个 server 返回成功
callCount := 0
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"error":"internal"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"cloudaicompanionProject": "fallback-project",
"currentTier": {"id": "free-tier", "name": "Free"}
}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err)
}
if resp.CloudAICompanionProject != "fallback-project" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
if callCount != 2 {
t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount)
}
}
func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`{"error":"unavailable"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte(`{"error":"bad_gateway"}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
}
}
func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := client.LoadCodeAssist(ctx, "token")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
// ---------------------------------------------------------------------------
// Client.FetchAvailableModels - 真正调用方法的测试
// ---------------------------------------------------------------------------
func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("请求方法不匹配: got %s, want POST", r.Method)
}
if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") {
t.Errorf("URL 路径不匹配: got %s", r.URL.Path)
}
auth := r.Header.Get("Authorization")
if auth != "Bearer test-token" {
t.Errorf("Authorization 不匹配: got %s", auth)
}
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type 不匹配: got %s", ct)
}
if ua := r.Header.Get("User-Agent"); ua != UserAgent {
t.Errorf("User-Agent 不匹配: got %s", ua)
}
// 验证请求体
var reqBody FetchAvailableModelsRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
t.Fatalf("解析请求体失败: %v", err)
}
if reqBody.Project != "project-abc" {
t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{
"models": {
"gemini-2.0-flash": {
"quotaInfo": {
"remainingFraction": 0.85,
"resetTime": "2025-01-01T00:00:00Z"
}
},
"gemini-2.5-pro": {
"quotaInfo": {
"remainingFraction": 0.5
}
}
}
}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
}
if resp.Models == nil {
t.Fatal("Models 不应为 nil")
}
if len(resp.Models) != 2 {
t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models))
}
flashModel, ok := resp.Models["gemini-2.0-flash"]
if !ok {
t.Fatal("缺少 gemini-2.0-flash 模型")
}
if flashModel.QuotaInfo == nil {
t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil")
}
if flashModel.QuotaInfo.RemainingFraction != 0.85 {
t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction)
}
if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" {
t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime)
}
proModel, ok := resp.Models["gemini-2.5-pro"]
if !ok {
t.Fatal("缺少 gemini-2.5-pro 模型")
}
if proModel.QuotaInfo == nil {
t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil")
}
if proModel.QuotaInfo.RemainingFraction != 0.5 {
t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction)
}
// 验证原始 JSON map
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
if rawResp["models"] == nil {
t.Error("rawResp models 不应为 nil")
}
}
func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
}
if !strings.Contains(err.Error(), "fetchAvailableModels 失败") {
t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error())
}
}
func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`<<<not json>>>`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
}
if !strings.Contains(err.Error(), "响应解析失败") {
t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error())
}
}
func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) {
callCount := 0
// 第一个 server 返回 429,第二个 server 返回成功
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":"rate_limited"}`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models": {"model-a": {}}}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err)
}
if _, ok := resp.Models["model-a"]; !ok {
t.Error("应返回 fallback server 的模型")
}
if callCount != 2 {
t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount)
}
}
func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`not found`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`internal error`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
}
}
func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := client.FetchAvailableModels(ctx, "token", "proj")
if err == nil {
t.Fatal("context 取消时应返回错误")
}
}
func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models": {}}`))
}))
defer server.Close()
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
}
if resp.Models == nil {
t.Fatal("Models 不应为 nil")
}
if len(resp.Models) != 0 {
t.Errorf("Models 应为空: got %d", len(resp.Models))
}
if rawResp == nil {
t.Fatal("rawResp 不应为 nil")
}
}
// ---------------------------------------------------------------------------
// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试
// ---------------------------------------------------------------------------
func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusRequestTimeout)
_, _ = w.Write([]byte(`timeout`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err)
}
if resp.CloudAICompanionProject != "p2" {
t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject)
}
}
func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) {
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`not found`))
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`))
}))
defer server2.Close()
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err)
}
if _, ok := resp.Models["m1"]; !ok {
t.Error("应返回 fallback server 的模型 m1")
}
}
func TestExtractProjectIDFromOnboardResponse(t *testing.T) { func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
t.Parallel() t.Parallel()
......
...@@ -6,10 +6,14 @@ import ( ...@@ -6,10 +6,14 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
const ( const (
...@@ -20,7 +24,11 @@ const ( ...@@ -20,7 +24,11 @@ const (
// Antigravity OAuth 客户端凭证 // Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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" ClientSecret = ""
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri(用户需手动复制 code) // 固定的 redirect_uri(用户需手动复制 code)
RedirectURI = "http://localhost:8085/callback" RedirectURI = "http://localhost:8085/callback"
...@@ -46,6 +54,18 @@ const ( ...@@ -46,6 +54,18 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
) )
func getClientSecret() (string, error) {
if v := strings.TrimSpace(ClientSecret); v != "" {
return v, nil
}
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
if vv := strings.TrimSpace(v); vv != "" {
return vv, nil
}
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
}
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{ var BaseURLs = []string{
antigravityProdBaseURL, // prod (优先) antigravityProdBaseURL, // prod (优先)
......
//go:build unit
package antigravity
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/url"
"strings"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// getClientSecret
// ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "my-secret-value" {
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
}
}
func TestGetClientSecret_环境变量为空(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量为空时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
}
}
func TestGetClientSecret_环境变量未设置(t *testing.T) {
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
// 明确设置再取消,确保环境变量不存在
t.Setenv(AntigravityOAuthClientSecretEnv, "")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量未设置时应返回错误")
}
}
func TestGetClientSecret_环境变量含空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量仅含空格时应返回错误")
}
}
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "valid-secret" {
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
}
}
// ---------------------------------------------------------------------------
// ForwardBaseURLs
// ---------------------------------------------------------------------------
func TestForwardBaseURLs_Daily优先(t *testing.T) {
urls := ForwardBaseURLs()
if len(urls) == 0 {
t.Fatal("ForwardBaseURLs 返回空列表")
}
// daily URL 应排在第一位
if urls[0] != antigravityDailyBaseURL {
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
}
// 应包含所有 URL
if len(urls) != len(BaseURLs) {
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
// 验证 prod URL 也在列表中
found := false
for _, u := range urls {
if u == antigravityProdBaseURL {
found = true
break
}
}
if !found {
t.Error("ForwardBaseURLs 中缺少 prod URL")
}
}
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
originalFirst := BaseURLs[0]
_ = ForwardBaseURLs()
// 确保原始 BaseURLs 未被修改
if BaseURLs[0] != originalFirst {
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
}
}
// ---------------------------------------------------------------------------
// URLAvailability
// ---------------------------------------------------------------------------
func TestNewURLAvailability(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if ua == nil {
t.Fatal("NewURLAvailability 返回 nil")
}
if ua.ttl != 5*time.Minute {
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
}
if ua.unavailable == nil {
t.Error("unavailable map 不应为 nil")
}
}
func TestURLAvailability_MarkUnavailable(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后 IsAvailable 应返回 false")
}
}
func TestURLAvailability_MarkSuccess(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
// 先标记为不可用
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后应不可用")
}
// 标记成功后应恢复可用
ua.MarkSuccess(testURL)
if !ua.IsAvailable(testURL) {
t.Error("MarkSuccess 后应恢复可用")
}
// 验证 lastSuccess 被设置
ua.mu.RLock()
if ua.lastSuccess != testURL {
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
}
ua.mu.RUnlock()
}
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
// 使用极短的 TTL
ua := NewURLAvailability(1 * time.Millisecond)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
// 等待 TTL 过期
time.Sleep(5 * time.Millisecond)
if !ua.IsAvailable(testURL) {
t.Error("TTL 过期后 URL 应恢复可用")
}
}
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if !ua.IsAvailable("https://never-marked.com") {
t.Error("未标记的 URL 应默认可用")
}
}
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
// 默认所有 URL 都可用
urls := ua.GetAvailableURLs()
if len(urls) != len(BaseURLs) {
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
}
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
if len(BaseURLs) < 2 {
t.Skip("BaseURLs 少于 2 个,跳过此测试")
}
ua.MarkUnavailable(BaseURLs[0])
urls := ua.GetAvailableURLs()
// 标记的 URL 不应出现在可用列表中
for _, u := range urls {
if u == BaseURLs[0] {
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
}
}
}
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
ua.MarkSuccess("https://c.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
// c.com 应排在第一位
if urls[0] != "https://c.com" {
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
}
// 其余按原始顺序
if urls[1] != "https://a.com" {
t.Errorf("第二个应为 a.com: got %s", urls[1])
}
if urls[2] != "https://b.com" {
t.Errorf("第三个应为 b.com: got %s", urls[2])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://b.com")
ua.MarkUnavailable("https://b.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// b.com 被标记不可用,不应出现
if len(urls) != 1 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
}
if urls[0] != "https://a.com" {
t.Errorf("仅 a.com 应可用: got %s", urls[0])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://not-in-list.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// lastSuccess 不在自定义列表中,不应被添加
if len(urls) != 2 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
}
}
// ---------------------------------------------------------------------------
// SessionStore
// ---------------------------------------------------------------------------
func TestNewSessionStore(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
if store == nil {
t.Fatal("NewSessionStore 返回 nil")
}
if store.sessions == nil {
t.Error("sessions map 不应为 nil")
}
}
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
CodeVerifier: "test-verifier",
ProxyURL: "http://proxy.example.com",
CreatedAt: time.Now(),
}
store.Set("session-1", session)
got, ok := store.Get("session-1")
if !ok {
t.Fatal("Get 应返回 true")
}
if got.State != "test-state" {
t.Errorf("State 不匹配: got %s", got.State)
}
if got.CodeVerifier != "test-verifier" {
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
}
if got.ProxyURL != "http://proxy.example.com" {
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
}
}
func TestSessionStore_Get_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("nonexistent")
if ok {
t.Error("不存在的 session 应返回 false")
}
}
func TestSessionStore_Get_过期(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "expired-state",
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
}
store.Set("expired-session", session)
_, ok := store.Get("expired-session")
if ok {
t.Error("过期的 session 应返回 false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
CreatedAt: time.Now(),
}
store.Set("del-session", session)
store.Delete("del-session")
_, ok := store.Get("del-session")
if ok {
t.Error("删除后 Get 应返回 false")
}
}
func TestSessionStore_Delete_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 删除不存在的 session 不应 panic
store.Delete("nonexistent")
}
func TestSessionStore_Stop(t *testing.T) {
store := NewSessionStore()
store.Stop()
// 多次 Stop 不应 panic
store.Stop()
}
func TestSessionStore_多个Session(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
for i := 0; i < 10; i++ {
session := &OAuthSession{
State: "state-" + string(rune('0'+i)),
CreatedAt: time.Now(),
}
store.Set("session-"+string(rune('0'+i)), session)
}
// 验证都能取到
for i := 0; i < 10; i++ {
_, ok := store.Get("session-" + string(rune('0'+i)))
if !ok {
t.Errorf("session-%d 应存在", i)
}
}
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes_长度正确(t *testing.T) {
sizes := []int{0, 1, 16, 32, 64, 128}
for _, size := range sizes {
b, err := GenerateRandomBytes(size)
if err != nil {
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
}
if len(b) != size {
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
}
}
}
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
b1, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第一次调用失败: %v", err)
}
b2, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第二次调用失败: %v", err)
}
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
if string(b1) == string(b2) {
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState
// ---------------------------------------------------------------------------
func TestGenerateState_返回值格式(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState 失败: %v", err)
}
if state == "" {
t.Error("GenerateState 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(state, "+/=") {
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
}
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
if len(state) != 43 {
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
}
}
func TestGenerateState_唯一性(t *testing.T) {
s1, _ := GenerateState()
s2, _ := GenerateState()
if s1 == s2 {
t.Error("两次 GenerateState 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID
// ---------------------------------------------------------------------------
func TestGenerateSessionID_返回值格式(t *testing.T) {
id, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID 失败: %v", err)
}
if id == "" {
t.Error("GenerateSessionID 返回空字符串")
}
// 16 字节的 hex 编码长度应为 32
if len(id) != 32 {
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
}
// 验证是合法的 hex 字符串
if _, err := hex.DecodeString(id); err != nil {
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
}
}
func TestGenerateSessionID_唯一性(t *testing.T) {
id1, _ := GenerateSessionID()
id2, _ := GenerateSessionID()
if id1 == id2 {
t.Error("两次 GenerateSessionID 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(verifier, "+/=") {
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
}
// 32 字节的 base64url 编码长度应为 43
if len(verifier) != 43 {
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
}
}
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
v1, _ := GenerateCodeVerifier()
v2, _ := GenerateCodeVerifier()
if v1 == v2 {
t.Error("两次 GenerateCodeVerifier 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
challenge := GenerateCodeChallenge(verifier)
// 手动计算预期值
hash := sha256.Sum256([]byte(verifier))
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
if challenge != expected {
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
}
}
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier")
if strings.Contains(challenge, "=") {
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
challenge := GenerateCodeChallenge("another-verifier")
if strings.ContainsAny(challenge, "+/") {
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
c1 := GenerateCodeChallenge("same-verifier")
c2 := GenerateCodeChallenge("same-verifier")
if c1 != c2 {
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
}
}
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
c1 := GenerateCodeChallenge("verifier-1")
c2 := GenerateCodeChallenge("verifier-2")
if c1 == c2 {
t.Error("不同输入应产生不同输出")
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
state := "test-state-123"
codeChallenge := "test-challenge-abc"
authURL := BuildAuthorizationURL(state, codeChallenge)
// 验证以 AuthorizeURL 开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
}
// 解析 URL 并验证参数
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
expectedParams := map[string]string{
"client_id": ClientID,
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"include_granted_scopes": "true",
}
for key, want := range expectedParams {
got := params.Get(key)
if got != want {
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
}
}
}
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
authURL := BuildAuthorizationURL("s", "c")
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
// 应包含 10 个参数
expectedCount := 10
if len(params) != expectedCount {
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
}
}
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
state := "state+with/special=chars"
codeChallenge := "challenge+value"
authURL := BuildAuthorizationURL(state, codeChallenge)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
// 解析后应正确还原特殊字符
if got := parsed.Query().Get("state"); got != state {
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
}
}
// ---------------------------------------------------------------------------
// 常量值验证
// ---------------------------------------------------------------------------
func TestConstants_值正确(t *testing.T) {
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
}
if TokenURL != "https://oauth2.googleapis.com/token" {
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
}
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
}
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID)
}
if ClientSecret != "" {
t.Error("ClientSecret 应为空字符串")
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if UserAgent != "antigravity/1.15.8 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", UserAgent)
}
if SessionTTL != 30*time.Minute {
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
}
if URLAvailabilityTTL != 5*time.Minute {
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
}
}
func TestScopes_包含必要范围(t *testing.T) {
expectedScopes := []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
for _, scope := range expectedScopes {
if !strings.Contains(Scopes, scope) {
t.Errorf("Scopes 缺少 %s", scope)
}
}
}
package antigravity package antigravity
import ( import (
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"strings" "strings"
"sync/atomic"
"time"
) )
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
...@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { ...@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return builder.String() return builder.String()
} }
// generateRandomID 生成随机 ID // fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
var fallbackCounter uint64
// generateRandomID 生成密码学安全的随机 ID
func generateRandomID() string { func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12) id := make([]byte, 12)
for i := range result { randBytes := make([]byte, 12)
result[i] = chars[i%len(chars)] if _, err := rand.Read(randBytes); err != nil {
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
cnt := atomic.AddUint64(&fallbackCounter, 1)
seed := uint64(time.Now().UnixNano()) ^ cnt
seed ^= uint64(len(err.Error())) << 32
for i := range id {
seed ^= seed << 13
seed ^= seed >> 7
seed ^= seed << 17
id[i] = chars[int(seed)%len(chars)]
}
return string(id)
}
for i, b := range randBytes {
id[i] = chars[int(b)%len(chars)]
} }
return string(result) return string(id)
} }
//go:build unit
package antigravity
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
func TestGenerateRandomID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
id := generateRandomID()
require.Len(t, id, 12, "ID 长度应为 12")
_, dup := seen[id]
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
seen[id] = struct{}{}
}
}
func TestFallbackCounter_Increments(t *testing.T) {
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
before := atomic.LoadUint64(&fallbackCounter)
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
}
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
// 验证并发递增的原子性 — 每次递增都应产生唯一值
const goroutines = 50
results := make([]uint64, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
}(i)
}
wg.Wait()
// 所有结果应唯一
seen := make(map[uint64]bool, goroutines)
for _, v := range results {
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
seen[v] = true
}
}
func TestGenerateRandomID_Charset(t *testing.T) {
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet := make(map[byte]struct{}, len(validChars))
for i := 0; i < len(validChars); i++ {
validSet[validChars[i]] = struct{}{}
}
for i := 0; i < 50; i++ {
id := generateRandomID()
for j := 0; j < len(id); j++ {
_, ok := validSet[id[j]]
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
}
}
}
func TestGenerateRandomID_Length(t *testing.T) {
for i := 0; i < 100; i++ {
id := generateRandomID()
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
}
}
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
// 验证并发调用不会产生重复 ID
const goroutines = 100
results := make([]string, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = generateRandomID()
}(i)
}
wg.Wait()
seen := make(map[string]bool, goroutines)
for _, id := range results {
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
seen[id] = true
}
}
func BenchmarkGenerateRandomID(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = generateRandomID()
}
}
...@@ -8,9 +8,21 @@ const ( ...@@ -8,9 +8,21 @@ const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform" ForcePlatform Key = "ctx_force_platform"
// RequestID 为服务端生成/透传的请求 ID。
RequestID Key = "ctx_request_id"
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。 // ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id" ClientRequestID Key = "ctx_client_request_id"
// Model 请求模型标识(用于统一请求链路日志字段)。
Model Key = "ctx_model"
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
Platform Key = "ctx_platform"
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
AccountID Key = "ctx_account_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count" RetryCount Key = "ctx_retry_count"
...@@ -32,4 +44,12 @@ const ( ...@@ -32,4 +44,12 @@ const (
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry Key = "ctx_single_account_retry" SingleAccountRetry Key = "ctx_single_account_retry"
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
) )
...@@ -38,8 +38,13 @@ const ( ...@@ -38,8 +38,13 @@ const (
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI. // GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may // They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client. // restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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" // GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
GeminiCLIOAuthClientSecret = ""
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
......
...@@ -6,10 +6,14 @@ import ( ...@@ -6,10 +6,14 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
type OAuthConfig struct { type OAuthConfig struct {
...@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error ...@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
} }
// Fall back to built-in Gemini CLI OAuth client when not configured. // Fall back to built-in Gemini CLI OAuth client when not configured.
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
if effective.ClientID == "" && effective.ClientSecret == "" { if effective.ClientID == "" && effective.ClientSecret == "" {
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
if secret == "" {
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
secret = strings.TrimSpace(v)
}
}
if secret == "" {
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
}
effective.ClientID = GeminiCLIOAuthClientID effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret effective.ClientSecret = secret
} else if effective.ClientID == "" || effective.ClientSecret == "" { } else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
} }
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID && isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" { if effective.Scopes == "" {
// Use different default scopes based on OAuth type // Use different default scopes based on OAuth type
......
package geminicli package geminicli
import ( import (
"encoding/hex"
"strings" "strings"
"sync"
"testing" "testing"
"time"
) )
// ---------------------------------------------------------------------------
// SessionStore 测试
// ---------------------------------------------------------------------------
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("sid-1", session)
got, ok := store.Get("sid-1")
if !ok {
t.Fatal("期望 Get 返回 ok=true,实际返回 false")
}
if got.State != "test-state" {
t.Errorf("期望 State=%q,实际=%q", "test-state", got.State)
}
}
func TestSessionStore_GetNotFound(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("不存在的ID")
if ok {
t.Error("期望不存在的 sessionID 返回 ok=false")
}
}
func TestSessionStore_GetExpired(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前)
session := &OAuthSession{
State: "expired-state",
OAuthType: "code_assist",
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
}
store.Set("expired-sid", session)
_, ok := store.Get("expired-sid")
if ok {
t.Error("期望过期的 session 返回 ok=false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("del-sid", session)
// 先确认存在
if _, ok := store.Get("del-sid"); !ok {
t.Fatal("删除前 session 应该存在")
}
store.Delete("del-sid")
if _, ok := store.Get("del-sid"); ok {
t.Error("删除后 session 不应该存在")
}
}
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
// 多次调用 Stop 不应 panic
store.Stop()
store.Stop()
store.Stop()
}
func TestSessionStore_ConcurrentAccess(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines * 3)
// 并发写入
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Set(sid, &OAuthSession{
State: sid,
OAuthType: "code_assist",
CreatedAt: time.Now(),
})
}(i)
}
// 并发读取
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
}(i)
}
// 并发删除
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Delete(sid)
}(i)
}
wg.Wait()
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes 测试
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes(t *testing.T) {
tests := []int{0, 1, 16, 32, 64}
for _, n := range tests {
b, err := GenerateRandomBytes(n)
if err != nil {
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
continue
}
if len(b) != n {
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n)
}
}
}
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
// 两次调用应该返回不同的结果(极小概率相同,32字节足够)
a, _ := GenerateRandomBytes(32)
b, _ := GenerateRandomBytes(32)
if string(a) == string(b) {
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState 测试
// ---------------------------------------------------------------------------
func TestGenerateState(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState() 出错: %v", err)
}
if state == "" {
t.Error("GenerateState() 返回空字符串")
}
// base64url 编码不应包含 padding '='
if strings.Contains(state, "=") {
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
}
// base64url 不应包含 '+' 或 '/'
if strings.ContainsAny(state, "+/") {
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID 测试
// ---------------------------------------------------------------------------
func TestGenerateSessionID(t *testing.T) {
sid, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID() 出错: %v", err)
}
// 16 字节 -> 32 个 hex 字符
if len(sid) != 32 {
t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid))
}
// 必须是合法的 hex 字符串
if _, err := hex.DecodeString(sid); err != nil {
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
}
}
func TestGenerateSessionID_Uniqueness(t *testing.T) {
a, _ := GenerateSessionID()
b, _ := GenerateSessionID()
if a == b {
t.Error("两次 GenerateSessionID() 返回了相同结果")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier() 返回空字符串")
}
// RFC 7636 要求 code_verifier 至少 43 个字符
if len(verifier) < 43 {
t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier))
}
// base64url 编码不应包含 padding 和非 URL 安全字符
if strings.Contains(verifier, "=") {
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
}
if strings.ContainsAny(verifier, "+/") {
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge(t *testing.T) {
// 使用已知输入验证输出
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
challenge := GenerateCodeChallenge(verifier)
if challenge != expected {
t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected)
}
}
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier-string")
if strings.Contains(challenge, "=") {
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
}
}
// ---------------------------------------------------------------------------
// base64URLEncode 测试
// ---------------------------------------------------------------------------
func TestBase64URLEncode(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"空字节", []byte{}},
{"单字节", []byte{0xff}},
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
{"全零", []byte{0x00, 0x00, 0x00}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := base64URLEncode(tt.input)
// 不应包含 '=' padding
if strings.Contains(result, "=") {
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
}
// 不应包含标准 base64 的 '+' 或 '/'
if strings.ContainsAny(result, "+/") {
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
}
})
}
}
// ---------------------------------------------------------------------------
// hasRestrictedScope 测试
// ---------------------------------------------------------------------------
func TestHasRestrictedScope(t *testing.T) {
tests := []struct {
scope string
expected bool
}{
// 受限 scope
{"https://www.googleapis.com/auth/generative-language", true},
{"https://www.googleapis.com/auth/generative-language.retriever", true},
{"https://www.googleapis.com/auth/generative-language.tuning", true},
{"https://www.googleapis.com/auth/drive", true},
{"https://www.googleapis.com/auth/drive.readonly", true},
{"https://www.googleapis.com/auth/drive.file", true},
// 非受限 scope
{"https://www.googleapis.com/auth/cloud-platform", false},
{"https://www.googleapis.com/auth/userinfo.email", false},
{"https://www.googleapis.com/auth/userinfo.profile", false},
// 边界情况
{"", false},
{"random-scope", false},
}
for _, tt := range tests {
t.Run(tt.scope, func(t *testing.T) {
got := hasRestrictedScope(tt.scope)
if got != tt.expected {
t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected)
}
})
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL 测试
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
// 检查返回的 URL 包含期望的参数
checks := []string{
"response_type=code",
"client_id=" + GeminiCLIOAuthClientID,
"redirect_uri=",
"state=test-state",
"code_challenge=test-challenge",
"code_challenge_method=S256",
"access_type=offline",
"prompt=consent",
"include_granted_scopes=true",
}
for _, check := range checks {
if !strings.Contains(authURL, check) {
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
}
}
// 不应包含 project_id(因为传的是空字符串)
if strings.Contains(authURL, "project_id=") {
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
}
// URL 应该以正确的授权端点开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
}
}
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
_, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"", // 空 redirectURI
"",
"code_assist",
)
if err == nil {
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
}
if !strings.Contains(err.Error(), "redirect_uri") {
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
}
}
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"my-project-123",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
if !strings.Contains(authURL, "project_id=my-project-123") {
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
}
}
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err == nil {
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 原有测试
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
tests := []struct { tests := []struct {
name string name string
input OAuthConfig input OAuthConfig
...@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Google One with built-in client (empty config)", name: "Google One 使用内置客户端(空配置)",
input: OAuthConfig{}, input: OAuthConfig{},
oauthType: "google_one", oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID, wantClientID: GeminiCLIOAuthClientID,
...@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One always uses built-in client (even if custom credentials passed)", name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
input: OAuthConfig{ input: OAuthConfig{
ClientID: "custom-client-id", ClientID: "custom-client-id",
ClientSecret: "custom-client-secret", ClientSecret: "custom-client-secret",
}, },
oauthType: "google_one", oauthType: "google_one",
wantClientID: "custom-client-id", wantClientID: "custom-client-id",
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client wantScopes: DefaultCodeAssistScopes,
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One with built-in client and custom scopes (should filter restricted scopes)", name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes",
input: OAuthConfig{ input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
}, },
...@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One with built-in client and only restricted scopes (should fallback to default)", name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)",
input: OAuthConfig{ input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
}, },
...@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Code Assist with built-in client", name: "Code Assist 使用内置客户端",
input: OAuthConfig{}, input: OAuthConfig{},
oauthType: "code_assist", oauthType: "code_assist",
wantClientID: GeminiCLIOAuthClientID, wantClientID: GeminiCLIOAuthClientID,
...@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
} }
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
// Test that Google One with built-in client filters out restricted scopes t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 测试 Google One + 内置客户端过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{ cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile", Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
}, "google_one") }, "google_one")
...@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { ...@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
t.Fatalf("EffectiveOAuthConfig() error = %v", err) t.Fatalf("EffectiveOAuthConfig() error = %v", err)
} }
// Should only contain cloud-platform, userinfo.email, and userinfo.profile // 应仅包含 cloud-platformuserinfo.email userinfo.profile
// Should NOT contain generative-language or drive scopes // 不应包含 generative-language drive scopes
if strings.Contains(cfg.Scopes, "generative-language") { if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes) t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes)
} }
if strings.Contains(cfg.Scopes, "drive") { if strings.Contains(cfg.Scopes, "drive") {
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes) t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes)
} }
if !strings.Contains(cfg.Scopes, "cloud-platform") { if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes) t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
} }
if !strings.Contains(cfg.Scopes, "userinfo.email") { if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes) t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
} }
if !strings.Contains(cfg.Scopes, "userinfo.profile") { if !strings.Contains(cfg.Scopes, "userinfo.profile") {
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes) t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 新增分支覆盖
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
// 只提供 clientID 不提供 secret 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "some-client-id",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
}
}
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
// 只提供 secret 不提供 clientID 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientSecret: "some-client-secret",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope)
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
// ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultAIStudioScopes {
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
// 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever)
parts := strings.Fields(cfg.Scopes)
for _, p := range parts {
if p == "https://www.googleapis.com/auth/generative-language" {
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes)
}
}
}
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 逗号分隔的 scopes 应被归一化为空格分隔
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 应该用空格分隔,而非逗号
if strings.Contains(cfg.Scopes, ",") {
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
// 混合逗号和空格分隔的 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
parts := strings.Fields(cfg.Scopes)
if len(parts) != 3 {
t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
// 输入中的前后空白应被清理
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: " custom-id ",
ClientSecret: " custom-secret ",
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.ClientID != "custom-id" {
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
}
if cfg.ClientSecret != "custom-secret" {
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
}
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
// 不设置环境变量且不提供凭据,应该报错
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
if err == nil {
t.Error("没有内置 secret 且未提供凭据时应该报错")
}
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
}
}
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 内置客户端应过滤 generative-language.retriever
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 未知的 oauthType 应回退到默认的 code_assist scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
}, "google_one")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 自定义客户端不应过滤任何 scope
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "drive.readonly") {
t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes)
} }
} }
...@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { ...@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return normalizeIP(c.ClientIP()) return normalizeIP(c.ClientIP())
} }
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func GetTrustedClientIP(c *gin.Context) string {
if c == nil {
return ""
}
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。 // normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string { func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip) ip = strings.TrimSpace(ip)
...@@ -54,29 +64,34 @@ func normalizeIP(ip string) string { ...@@ -54,29 +64,34 @@ func normalizeIP(ip string) string {
return ip return ip
} }
// isPrivateIP 检查 IP 是否为私有地址。 // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
func isPrivateIP(ipStr string) bool { var privateNets []*net.IPNet
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围 func init() {
privateBlocks := []string{ for _, cidr := range []string{
"10.0.0.0/8", "10.0.0.0/8",
"172.16.0.0/12", "172.16.0.0/12",
"192.168.0.0/16", "192.168.0.0/16",
"127.0.0.0/8", "127.0.0.0/8",
"::1/128", "::1/128",
"fc00::/7", "fc00::/7",
} } {
_, block, err := net.ParseCIDR(cidr)
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil { if err != nil {
continue panic("invalid CIDR: " + cidr)
} }
if cidr.Contains(ip) { privateNets = append(privateNets, block)
}
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
for _, block := range privateNets {
if block.Contains(ip) {
return true return true
} }
} }
......
//go:build unit
package ip
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// 私有 IPv4
{"10.x 私有地址", "10.0.0.1", true},
{"10.x 私有地址段末", "10.255.255.255", true},
{"172.16.x 私有地址", "172.16.0.1", true},
{"172.31.x 私有地址", "172.31.255.255", true},
{"192.168.x 私有地址", "192.168.1.1", true},
{"127.0.0.1 本地回环", "127.0.0.1", true},
{"127.x 回环段", "127.255.255.255", true},
// 公网 IPv4
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
{"1.1.1.1 公网", "1.1.1.1", false},
{"172.15.255.255 非私有", "172.15.255.255", false},
{"172.32.0.0 非私有", "172.32.0.0", false},
{"11.0.0.1 公网", "11.0.0.1", false},
// IPv6
{"::1 IPv6 回环", "::1", true},
{"fc00:: IPv6 私有", "fc00::1", true},
{"fd00:: IPv6 私有", "fd00::1", true},
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
// 无效输入
{"空字符串", "", false},
{"非法字符串", "not-an-ip", false},
{"不完整 IP", "192.168", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isPrivateIP(tc.ip)
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
})
}
}
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
require.NoError(t, r.SetTrustedProxies(nil))
r.GET("/t", func(c *gin.Context) {
c.String(200, GetTrustedClientIP(c))
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
r.ServeHTTP(w, req)
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment