Commit 6b97a8be authored by Edric Li's avatar Edric Li
Browse files

Merge branch 'main' into feat/api-key-ip-restriction

parents 90798f14 62dc0b95
...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One with custom client", name: "Google One always uses built-in client (even if custom credentials passed)",
input: OAuthConfig{ input: OAuthConfig{
ClientID: "custom-client-id", ClientID: "custom-client-id",
ClientSecret: "custom-client-secret", ClientSecret: "custom-client-secret",
}, },
oauthType: "google_one", oauthType: "google_one",
wantClientID: "custom-client-id", wantClientID: "custom-client-id",
wantScopes: DefaultGoogleOneScopes, wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantErr: false, wantErr: false,
}, },
{ {
......
...@@ -831,6 +831,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -831,6 +831,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Status) args = append(args, *updates.Status)
idx++ idx++
} }
if updates.Schedulable != nil {
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
args = append(args, *updates.Schedulable)
idx++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 { if len(updates.Credentials) > 0 {
payload, err := json.Marshal(updates.Credentials) payload, err := json.Marshal(updates.Credentials)
......
...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c ...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType: // Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public) // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client // - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client // - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{ oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID, ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
......
...@@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { ...@@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", "", nil)
} }
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query() q := r.client.Group.Query()
if platform != "" { if platform != "" {
...@@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if status != "" { if status != "" {
q = q.Where(group.StatusEQ(status)) q = q.Where(group.StatusEQ(status))
} }
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil { if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive)) q = q.Where(group.IsExclusiveEQ(*isExclusive))
} }
......
...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination.PaginationParams{Page: 1, PageSize: 10}, pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI, service.PlatformOpenAI,
"", "",
"",
nil, nil,
) )
s.Require().NoError(err, "ListWithFilters base") s.Require().NoError(err, "ListWithFilters base")
...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1) s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform // Verify all groups are OpenAI platform
...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { ...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { ...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
})) }))
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive) s.Require().True(groups[0].IsExclusive)
} }
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{ g1 := &service.Group{
Name: "g1", Name: "g1",
...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { ...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(err) s.Require().NoError(err)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
......
...@@ -304,6 +304,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -304,6 +304,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true, "turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API", "site_name": "Sub2API",
"site_logo": "", "site_logo": "",
"site_subtitle": "Subtitle", "site_subtitle": "Subtitle",
...@@ -390,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -390,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
...@@ -583,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam ...@@ -583,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -19,6 +19,8 @@ func RegisterAuthRoutes( ...@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
...@@ -66,6 +66,7 @@ type AccountBulkUpdate struct { ...@@ -66,6 +66,7 @@ type AccountBulkUpdate struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status *string Status *string
Schedulable *bool
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
} }
......
...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok { if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion // Extract content first (before checking completion)
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
if content, ok := candidate["content"].(map[string]any); ok { if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok { if parts, ok := content["parts"].([]any); ok {
for _, part := range parts { for _, part := range parts {
...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
} }
} }
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
} }
} }
......
...@@ -24,7 +24,7 @@ type AdminService interface { ...@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct { ...@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status string Status string
Schedulable *bool
GroupIDs *[]int64 GroupIDs *[]int64
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
...@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Status != "" { if input.Status != "" {
repoUpdates.Status = &input.Status repoUpdates.Status = &input.Status
} }
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
// Run bulk update for column/jsonb fields first. // Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
......
...@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa ...@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { ...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated *Group // 记录 Update 调用的参数 updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值 getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误 getErr error // GetByID 返回的错误
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersIsExclusive *bool
listWithFiltersGroups []Group
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
} }
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
...@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP ...@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersIsExclusive = isExclusive
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersGroups)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersGroups, result, nil
} }
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
...@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, "alpha", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{},
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
require.Equal(t, "", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
isExclusive := true
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "beta", repo.listWithFiltersSearch)
require.NotNil(t, repo.listWithFiltersIsExclusive)
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type accountRepoStubForAdminList struct {
accountRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAccounts)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAccounts, result, nil
}
type proxyRepoStubForAdminList struct {
proxyRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersProtocol string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersProxies []Proxy
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
listWithFiltersAndAccountCountCalls int
listWithFiltersAndAccountCountParams pagination.PaginationParams
listWithFiltersAndAccountCountProtocol string
listWithFiltersAndAccountCountStatus string
listWithFiltersAndAccountCountSearch string
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
listWithFiltersAndAccountCountResult *pagination.PaginationResult
listWithFiltersAndAccountCountErr error
}
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersProtocol = protocol
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersProxies, result, nil
}
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
s.listWithFiltersAndAccountCountCalls++
s.listWithFiltersAndAccountCountParams = params
s.listWithFiltersAndAccountCountProtocol = protocol
s.listWithFiltersAndAccountCountStatus = status
s.listWithFiltersAndAccountCountSearch = search
if s.listWithFiltersAndAccountCountErr != nil {
return nil, nil, s.listWithFiltersAndAccountCountErr
}
result := s.listWithFiltersAndAccountCountResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAndAccountCountProxies, result, nil
}
type redeemRepoStubForAdminList struct {
redeemRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersCodes []RedeemCode
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersType = codeType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersCodes)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersCodes, result, nil
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "acc", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
})
}
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &redeemRepoStubForAdminList{
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
})
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"io" "io"
"log" "log"
mathrand "math/rand" mathrand "math/rand"
"net"
"net/http" "net/http"
"strings" "strings"
"sync/atomic" "sync/atomic"
...@@ -27,6 +28,32 @@ const ( ...@@ -27,6 +28,32 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool {
if err == nil {
return false
}
// 检查超时错误
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// 检查连接错误(DNS 失败、连接拒绝)
var opErr *net.OpError
return errors.As(err, &opErr)
}
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
if isAntigravityConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests
}
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪) // getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
func getSessionID(c *gin.Context) string { func getSessionID(c *gin.Context) string {
if c == nil { if c == nil {
...@@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("构建请求失败: %w", err) return nil, fmt.Errorf("构建请求失败: %w", err)
} }
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody)
if err != nil {
return nil, err
}
// 调试日志:Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 代理 URL // 代理 URL
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 发送请求 // URL fallback 循环
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if err != nil { if len(availableURLs) == 0 {
return nil, fmt.Errorf("请求失败: %w", err) availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
} }
defer func() { _ = resp.Body.Close() }()
// 读取响应 var lastErr error
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) for urlIdx, baseURL := range availableURLs {
if err != nil { // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
return nil, fmt.Errorf("读取响应失败: %w", err) req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
} if err != nil {
lastErr = err
continue
}
if resp.StatusCode >= 400 { // 调试日志:Test 请求信息
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
}
// 解析流式响应,提取文本 // 发送请求
text := extractTextFromSSEResponse(respBody) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
}
return &TestConnectionResult{ // 读取响应
Text: text, respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
MappedModel: mappedModel, _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
}, nil if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
}, nil
}
return nil, lastErr
} }
// buildGeminiTestRequest 构建 Gemini 格式测试请求 // buildGeminiTestRequest 构建 Gemini 格式测试请求
...@@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent" action := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { urlFallbackLoop:
// 检查 context 是否已取消(客户端断开连接) for urlIdx, baseURL := range availableURLs {
select { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
case <-ctx.Done(): // 检查 context 是否已取消(客户端断开连接)
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) select {
return nil, ctx.Err() case <-ctx.Done():
default: log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
} return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { // 检查是否应触发 URL 降级
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
if !sleepAntigravityBackoffWithContext(ctx, attempt) { antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
return nil, ctx.Err() continue urlFallbackLoop
} }
continue if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { // 检查是否应触发 URL 降级(仅 429)
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
_ = resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) { if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err() return nil, ctx.Err()
}
continue
} }
continue // 所有重试都失败,标记限流状态
} if resp.StatusCode == 429 {
// 所有重试都失败,标记限流状态 s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if resp.StatusCode == 429 { }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) // 最后一次尝试也失败
} resp = &http.Response{
// 最后一次尝试也失败 StatusCode: resp.StatusCode,
resp = &http.Response{ Header: resp.Header.Clone(),
StatusCode: resp.StatusCode, Body: io.NopCloser(bytes.NewReader(respBody)),
Header: resp.Header.Clone(), }
Body: io.NopCloser(bytes.NewReader(respBody)), break urlFallbackLoop
} }
break
}
break break urlFallbackLoop
}
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
...@@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent" upstreamAction := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { urlFallbackLoop:
// 检查 context 是否已取消(客户端断开连接) for urlIdx, baseURL := range availableURLs {
select { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
case <-ctx.Done(): // 检查 context 是否已取消(客户端断开连接)
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) select {
return nil, ctx.Err() case <-ctx.Done():
default: log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
} return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { // 检查是否应触发 URL 降级
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
if !sleepAntigravityBackoffWithContext(ctx, attempt) { antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
return nil, ctx.Err() continue urlFallbackLoop
} }
continue if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { // 检查是否应触发 URL 降级(仅 429)
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
_ = resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
if !sleepAntigravityBackoffWithContext(ctx, attempt) { if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err() return nil, ctx.Err()
}
continue
} }
continue // 所有重试都失败,标记限流状态
} if resp.StatusCode == 429 {
// 所有重试都失败,标记限流状态 s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if resp.StatusCode == 429 { }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) resp = &http.Response{
} StatusCode: resp.StatusCode,
resp = &http.Response{ Header: resp.Header.Clone(),
StatusCode: resp.StatusCode, Body: io.NopCloser(bytes.NewReader(respBody)),
Header: resp.Header.Clone(), }
Body: io.NopCloser(bytes.NewReader(respBody)), break urlFallbackLoop
} }
break
}
break break urlFallbackLoop
}
} }
defer func() { defer func() {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
......
...@@ -2,9 +2,13 @@ package service ...@@ -2,9 +2,13 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/mail"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -18,6 +22,7 @@ var ( ...@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
...@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str ...@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
} }
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
// 检查是否需要邮件验证 // 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 这是一个配置错误,不应该允许绕过验证
if s.emailService == nil {
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
return "", nil, ErrServiceUnavailable
}
if verifyCode == "" { if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired return "", nil, ErrEmailVerifyRequired
} }
// 验证邮箱验证码 // 验证邮箱验证码
if s.emailService != nil { if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { return "", nil, fmt.Errorf("verify code: %w", err)
return "", nil, fmt.Errorf("verify code: %w", err)
}
} }
} }
...@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if errors.Is(err, ErrEmailExists) {
return "", nil, ErrEmailExists
}
log.Printf("[Auth] Database error creating user: %v", err) log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
...@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct { ...@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式) // SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled return ErrRegDisabled
} }
if isReservedEmail(email) {
return ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { ...@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled") log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled return nil, ErrRegDisabled
} }
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { ...@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册 // IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil { if s.settingService == nil {
return true return false // 安全默认:settingService 未配置时关闭注册
} }
return s.settingService.IsRegistrationEnabled(ctx) return s.settingService.IsRegistrationEnabled(ctx)
} }
...@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return token, user, nil return token, user, nil
} }
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册。
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return "", nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return "", nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
...@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
if err != nil { if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) { if errors.Is(err, jwt.ErrTokenExpired) {
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
if claims, ok := token.Claims.(*JWTClaims); ok {
return claims, ErrTokenExpired
}
return nil, ErrTokenExpired return nil, ErrTokenExpired
} }
return nil, ErrInvalidToken return nil, ErrInvalidToken
...@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return nil, ErrInvalidToken return nil, ErrInvalidToken
} }
func randomHexString(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 16
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token // GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
......
...@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) { ...@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require.ErrorIs(t, err, ErrRegDisabled) require.ErrorIs(t, err, ErrRegDisabled)
} }
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
repo := &userRepoStub{} repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
service := newAuthService(repo, map[string]string{ service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, nil) }, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{} // 配置 emailService
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired) require.ErrorIs(t, err, ErrEmailVerifyRequired)
} }
...@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { ...@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true} repo := &userRepoStub{exists: true}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists) require.ErrorIs(t, err, ErrEmailExists)
...@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) { ...@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func TestAuthService_Register_CheckEmailError(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")} repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_CreateError(t *testing.T) { func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")} repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5} repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password") token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err) require.NoError(t, err)
...@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) { ...@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) {
require.Len(t, repo.created, 1) require.Len(t, repo.created, 1)
require.True(t, user.CheckPassword("password")) require.True(t, user.CheckPassword("password"))
} }
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
// 创建用户并生成 token
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
// 验证有效 token
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, int64(1), claims.UserID)
// 模拟过期 token(通过创建一个过期很久的 token)
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1 // 恢复
// 验证过期 token 应返回 claims 和 ErrTokenExpired
claims, err = service.ValidateToken(expiredToken)
require.ErrorIs(t, err, ErrTokenExpired)
require.NotNil(t, claims, "claims should not be nil when token is expired")
require.Equal(t, int64(1), claims.UserID)
require.Equal(t, "test@test.com", claims.Email)
}
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
repo := &userRepoStub{user: user}
service := newAuthService(repo, nil, nil)
// 创建过期 token
service.cfg.JWT.ExpireHour = -1
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1
// RefreshToken 使用过期 token 不应 panic
require.NotPanics(t, func() {
newToken, err := service.RefreshToken(context.Background(), expiredToken)
require.NoError(t, err)
require.NotEmpty(t, newToken)
})
}
...@@ -105,7 +105,17 @@ const ( ...@@ -105,7 +105,17 @@ const (
// Request identity patch (Claude -> Gemini systemInstruction injection) // Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt" SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
) )
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-" const AdminAPIKeyPrefix = "admin-"
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
...@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 // 验证码不匹配
if data.Code != code { if data.Code != code {
data.Attempts++ data.Attempts++
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
}
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
} }
...@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
} }
// 验证成功,删除验证码 // 验证成功,删除验证码
_ = s.cache.DeleteVerificationCode(ctx, email) if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
}
return nil return nil
} }
......
...@@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([ ...@@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil } func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
......
...@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
} }
// OAuth client selection: // OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client. // - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client. // - ai_studio: requires a user-provided OAuth client
oauthCfg := geminicli.OAuthConfig{ oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID, ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes, Scopes: s.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfg.ClientID = "" oauthCfg.ClientID = ""
oauthCfg.ClientSecret = "" oauthCfg.ClientSecret = ""
} }
...@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
case "google_one": case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type") log.Printf("[GeminiOAuth] Processing google_one OAuth type")
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
}
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
}
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier // Attempt to fetch Drive storage tier
var storageInfo *geminicli.DriveStorageInfo var storageInfo *geminicli.DriveStorageInfo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment