Commit 11bfc807 authored by song's avatar song
Browse files

merge upstream/main

parents c2a6ca8d 62dc0b95
...@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() { ...@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
} }
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() { func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent") _, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session") require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
} }
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1" sessionID := "s1"
accountID := int64(99) accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID) sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID") require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch") require.Equal(s.T(), accountID, sid, "session id mismatch")
} }
...@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { ...@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2" sessionID := "s2"
accountID := int64(100) accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set") require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL) s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
...@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { ...@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL() { func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3" sessionID := "s3"
accountID := int64(101) accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL") require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh") require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL) s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
...@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() { ...@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op) // RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute) err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
} }
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted" sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value // Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value") require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID) _, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value") require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
} }
......
...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c ...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType: // Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public) // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client // - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client // - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{ oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID, ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
......
...@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays) SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
...@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group ...@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
} }
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID). builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name). SetName(groupIn.Name).
SetDescription(groupIn.Description). SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform). SetPlatform(groupIn.Platform).
...@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx) SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
} }
...@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { ...@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", "", nil)
} }
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query() q := r.client.Group.Query()
if platform != "" { if platform != "" {
...@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if status != "" { if status != "" {
q = q.Where(group.StatusEQ(status)) q = q.Where(group.StatusEQ(status))
} }
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil { if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive)) q = q.Where(group.IsExclusiveEQ(*isExclusive))
} }
......
...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination.PaginationParams{Page: 1, PageSize: 10}, pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI, service.PlatformOpenAI,
"", "",
"",
nil, nil,
) )
s.Require().NoError(err, "ListWithFilters base") s.Require().NoError(err, "ListWithFilters base")
...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1) s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform // Verify all groups are OpenAI platform
...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { ...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { ...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
})) }))
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive) s.Require().True(groups[0].IsExclusive)
} }
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{ g1 := &service.Group{
Name: "g1", Name: "g1",
...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { ...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(err) s.Require().NoError(err)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
......
...@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return outProxies, paginationResultFromTotal(int64(total), params), nil return outProxies, paginationResultFromTotal(int64(total), params), nil
} }
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query(). proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)). Where(proxy.StatusEQ(service.StatusActive)).
......
...@@ -243,7 +243,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -243,7 +243,8 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50, "first_token_ms": 50,
"image_count": 0, "image_count": 0,
"image_size": null, "image_size": null,
"created_at": "2025-01-02T03:04:05Z" "created_at": "2025-01-02T03:04:05Z",
"user_agent": null
} }
], ],
"total": 1, "total": 1,
...@@ -303,6 +304,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -303,6 +304,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true, "turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API", "site_name": "Sub2API",
"site_logo": "", "site_logo": "",
"site_subtitle": "Subtitle", "site_subtitle": "Subtitle",
...@@ -389,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -389,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
...@@ -582,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam ...@@ -582,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -19,6 +19,8 @@ func RegisterAuthRoutes( ...@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
...@@ -68,6 +68,7 @@ type AccountBulkUpdate struct { ...@@ -68,6 +68,7 @@ type AccountBulkUpdate struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status *string Status *string
Schedulable *bool
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
} }
......
...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok { if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion // Extract content first (before checking completion)
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
if content, ok := candidate["content"].(map[string]any); ok { if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok { if parts, ok := content["parts"].([]any); ok {
for _, part := range parts { for _, part := range parts {
...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
} }
} }
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
} }
} }
......
...@@ -24,7 +24,7 @@ type AdminService interface { ...@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -47,6 +47,7 @@ type AdminService interface { ...@@ -47,6 +47,7 @@ type AdminService interface {
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
...@@ -99,9 +100,11 @@ type CreateGroupInput struct { ...@@ -99,9 +100,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
...@@ -116,9 +119,11 @@ type UpdateGroupInput struct { ...@@ -116,9 +119,11 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
} }
type CreateAccountInput struct { type CreateAccountInput struct {
...@@ -163,6 +168,7 @@ type BulkUpdateAccountsInput struct { ...@@ -163,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status string Status string
Schedulable *bool
GroupIDs *[]int64 GroupIDs *[]int64
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
...@@ -473,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -473,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -515,6 +521,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -515,6 +521,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K) imagePrice4K := normalizePrice(input.ImagePrice4K)
// 校验降级分组
if input.FallbackGroupID != nil {
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
return nil, err
}
}
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
...@@ -529,6 +542,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -529,6 +542,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -552,6 +567,29 @@ func normalizePrice(price *float64) *float64 { ...@@ -552,6 +567,29 @@ func normalizePrice(price *float64) *float64 {
return price return price
} }
// validateFallbackGroup 校验降级分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// fallbackGroupID: 降级分组 ID
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
// 不能将自己设置为降级分组
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -602,6 +640,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -602,6 +640,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.ImagePrice4K = normalizePrice(input.ImagePrice4K) group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
} }
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
}
if input.FallbackGroupID != nil {
// 校验降级分组
if *input.FallbackGroupID > 0 {
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
return nil, err
}
group.FallbackGroupID = input.FallbackGroupID
} else {
// 传入 0 或负数表示清除降级分组
group.FallbackGroupID = nil
}
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
...@@ -856,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -856,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Status != "" { if input.Status != "" {
repoUpdates.Status = &input.Status repoUpdates.Status = &input.Status
} }
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
// Run bulk update for column/jsonb fields first. // Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
...@@ -950,6 +1008,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, ...@@ -950,6 +1008,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx) return s.proxyRepo.ListActive(ctx)
} }
......
...@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa ...@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
...@@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy ...@@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy
panic("unexpected ListActiveWithAccountCount call") panic("unexpected ListActiveWithAccountCount call")
} }
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
panic("unexpected ListWithFiltersAndAccountCount call")
}
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
panic("unexpected ExistsByHostPortAuth call") panic("unexpected ExistsByHostPortAuth call")
} }
......
...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { ...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated *Group // 记录 Update 调用的参数 updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值 getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误 getErr error // GetByID 返回的错误
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersIsExclusive *bool
listWithFiltersGroups []Group
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
} }
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
...@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP ...@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersIsExclusive = isExclusive
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersGroups)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersGroups, result, nil
} }
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
...@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, "alpha", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{},
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
require.Equal(t, "", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
isExclusive := true
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "beta", repo.listWithFiltersSearch)
require.NotNil(t, repo.listWithFiltersIsExclusive)
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type accountRepoStubForAdminList struct {
accountRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAccounts)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAccounts, result, nil
}
type proxyRepoStubForAdminList struct {
proxyRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersProtocol string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersProxies []Proxy
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
listWithFiltersAndAccountCountCalls int
listWithFiltersAndAccountCountParams pagination.PaginationParams
listWithFiltersAndAccountCountProtocol string
listWithFiltersAndAccountCountStatus string
listWithFiltersAndAccountCountSearch string
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
listWithFiltersAndAccountCountResult *pagination.PaginationResult
listWithFiltersAndAccountCountErr error
}
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersProtocol = protocol
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersProxies, result, nil
}
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
s.listWithFiltersAndAccountCountCalls++
s.listWithFiltersAndAccountCountParams = params
s.listWithFiltersAndAccountCountProtocol = protocol
s.listWithFiltersAndAccountCountStatus = status
s.listWithFiltersAndAccountCountSearch = search
if s.listWithFiltersAndAccountCountErr != nil {
return nil, nil, s.listWithFiltersAndAccountCountErr
}
result := s.listWithFiltersAndAccountCountResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAndAccountCountProxies, result, nil
}
type redeemRepoStubForAdminList struct {
redeemRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersCodes []RedeemCode
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersType = codeType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersCodes)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersCodes, result, nil
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "acc", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
})
}
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &redeemRepoStubForAdminList{
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
})
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"io" "io"
"log" "log"
mathrand "math/rand" mathrand "math/rand"
"net"
"net/http" "net/http"
"strings" "strings"
"sync/atomic" "sync/atomic"
...@@ -27,6 +28,32 @@ const ( ...@@ -27,6 +28,32 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool {
if err == nil {
return false
}
// 检查超时错误
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// 检查连接错误(DNS 失败、连接拒绝)
var opErr *net.OpError
return errors.As(err, &opErr)
}
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
if isAntigravityConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests
}
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪) // getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
func getSessionID(c *gin.Context) string { func getSessionID(c *gin.Context) string {
if c == nil { if c == nil {
...@@ -182,45 +209,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -182,45 +209,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("构建请求失败: %w", err) return nil, fmt.Errorf("构建请求失败: %w", err)
} }
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody)
if err != nil {
return nil, err
}
// 调试日志:Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 代理 URL // 代理 URL
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 发送请求 // URL fallback 循环
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if err != nil { if len(availableURLs) == 0 {
return nil, fmt.Errorf("请求失败: %w", err) availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
} }
defer func() { _ = resp.Body.Close() }()
// 读取响应 var lastErr error
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) for urlIdx, baseURL := range availableURLs {
if err != nil { // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
return nil, fmt.Errorf("读取响应失败: %w", err) req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
} if err != nil {
lastErr = err
continue
}
if resp.StatusCode >= 400 { // 调试日志:Test 请求信息
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
}
// 解析流式响应,提取文本 // 发送请求
text := extractTextFromSSEResponse(respBody) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
}
return &TestConnectionResult{ // 读取响应
Text: text, respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
MappedModel: mappedModel, _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
}, nil if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
}, nil
}
return nil, lastErr
} }
// buildGeminiTestRequest 构建 Gemini 格式测试请求 // buildGeminiTestRequest 构建 Gemini 格式测试请求
...@@ -486,62 +538,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -486,62 +538,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent" action := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { urlFallbackLoop:
// 检查 context 是否已取消(客户端断开连接) for urlIdx, baseURL := range availableURLs {
select { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
case <-ctx.Done(): // 检查 context 是否已取消(客户端断开连接)
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) select {
return nil, ctx.Err() case <-ctx.Done():
default: log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
} return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { // 检查是否应触发 URL 降级
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
if !sleepAntigravityBackoffWithContext(ctx, attempt) { antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
return nil, ctx.Err() continue urlFallbackLoop
} }
continue if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { // 检查是否应触发 URL 降级(仅 429)
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
_ = resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) { if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err() return nil, ctx.Err()
}
continue
} }
continue // 所有重试都失败,标记限流状态
} if resp.StatusCode == 429 {
// 所有重试都失败,标记限流状态 s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if resp.StatusCode == 429 { }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) // 最后一次尝试也失败
} resp = &http.Response{
// 最后一次尝试也失败 StatusCode: resp.StatusCode,
resp = &http.Response{ Header: resp.Header.Clone(),
StatusCode: resp.StatusCode, Body: io.NopCloser(bytes.NewReader(respBody)),
Header: resp.Header.Clone(), }
Body: io.NopCloser(bytes.NewReader(respBody)), break urlFallbackLoop
} }
break
}
break break urlFallbackLoop
}
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
...@@ -1006,61 +1082,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1006,61 +1082,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent" upstreamAction := "streamGenerateContent"
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { urlFallbackLoop:
// 检查 context 是否已取消(客户端断开连接) for urlIdx, baseURL := range availableURLs {
select { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
case <-ctx.Done(): // 检查 context 是否已取消(客户端断开连接)
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) select {
return nil, ctx.Err() case <-ctx.Done():
default: log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
} return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { // 检查是否应触发 URL 降级
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
if !sleepAntigravityBackoffWithContext(ctx, attempt) { antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
return nil, ctx.Err() continue urlFallbackLoop
} }
continue if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { // 检查是否应触发 URL 降级(仅 429)
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
_ = resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
if !sleepAntigravityBackoffWithContext(ctx, attempt) { if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix) log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err() return nil, ctx.Err()
}
continue
} }
continue // 所有重试都失败,标记限流状态
} if resp.StatusCode == 429 {
// 所有重试都失败,标记限流状态 s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if resp.StatusCode == 429 { }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) resp = &http.Response{
} StatusCode: resp.StatusCode,
resp = &http.Response{ Header: resp.Header.Clone(),
StatusCode: resp.StatusCode, Body: io.NopCloser(bytes.NewReader(respBody)),
Header: resp.Header.Clone(), }
Body: io.NopCloser(bytes.NewReader(respBody)), break urlFallbackLoop
} }
break
}
break break urlFallbackLoop
}
} }
defer func() { defer func() {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
......
...@@ -2,9 +2,13 @@ package service ...@@ -2,9 +2,13 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/mail"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -18,6 +22,7 @@ var ( ...@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
...@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str ...@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
} }
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
// 检查是否需要邮件验证 // 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 这是一个配置错误,不应该允许绕过验证
if s.emailService == nil {
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
return "", nil, ErrServiceUnavailable
}
if verifyCode == "" { if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired return "", nil, ErrEmailVerifyRequired
} }
// 验证邮箱验证码 // 验证邮箱验证码
if s.emailService != nil { if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { return "", nil, fmt.Errorf("verify code: %w", err)
return "", nil, fmt.Errorf("verify code: %w", err)
}
} }
} }
...@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if errors.Is(err, ErrEmailExists) {
return "", nil, ErrEmailExists
}
log.Printf("[Auth] Database error creating user: %v", err) log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
...@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct { ...@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式) // SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled return ErrRegDisabled
} }
if isReservedEmail(email) {
return ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { ...@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled") log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled return nil, ErrRegDisabled
} }
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { ...@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册 // IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil { if s.settingService == nil {
return true return false // 安全默认:settingService 未配置时关闭注册
} }
return s.settingService.IsRegistrationEnabled(ctx) return s.settingService.IsRegistrationEnabled(ctx)
} }
...@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return token, user, nil return token, user, nil
} }
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册。
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return "", nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return "", nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
...@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
if err != nil { if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) { if errors.Is(err, jwt.ErrTokenExpired) {
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
if claims, ok := token.Claims.(*JWTClaims); ok {
return claims, ErrTokenExpired
}
return nil, ErrTokenExpired return nil, ErrTokenExpired
} }
return nil, ErrInvalidToken return nil, ErrInvalidToken
...@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return nil, ErrInvalidToken return nil, ErrInvalidToken
} }
func randomHexString(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 16
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token // GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
......
This diff is collapsed.
...@@ -105,7 +105,17 @@ const ( ...@@ -105,7 +105,17 @@ const (
// Request identity patch (Claude -> Gemini systemInstruction injection) // Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt" SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
) )
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-" const AdminAPIKeyPrefix = "admin-"
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment