Unverified Commit 1ef3782d authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1538 from IanShaw027/fix/bug-cleanup-main

 fix: 修复多个 UI 和功能问题 - 表格排序搜索、导出逻辑、分页配置和状态筛选
parents 00c08c57 f480e573
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, g := range groups {
indexByID[g.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
// g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "sort_order",
SortOrder: "desc",
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, group := range groups {
indexByID[group.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
...@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams) ...@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
Pages: pages, Pages: pages,
} }
} }
func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
if len(items) == 0 {
return []T{}
}
offset := params.Offset()
if offset >= len(items) {
return []T{}
}
limit := params.Limit()
end := offset + limit
if end > len(items) {
end = len(items)
}
return items[offset:end]
}
...@@ -2,12 +2,15 @@ package repository ...@@ -2,12 +2,15 @@ package repository
import ( import (
"context" "context"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type promoCodeRepository struct { type promoCodeRepository struct {
...@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina ...@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(promocode.FieldID)). for _, order := range promoCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina ...@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "bonus_amount":
field = promocode.FieldBonusAmount
case "status":
field = promocode.FieldStatus
case "expires_at":
field = promocode.FieldExpiresAt
case "created_at":
field = promocode.FieldCreatedAt
case "code":
field = promocode.FieldCode
default:
field = promocode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error { func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create(). created, err := client.PromoCodeUsage.Create().
......
...@@ -3,12 +3,16 @@ package repository ...@@ -3,12 +3,16 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
) )
type sqlQuerier interface { type sqlQuerier interface {
...@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
proxies, err := q. proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(proxy.FieldID)). for _, order := range proxyListOrder(params) {
All(ctx) proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa ...@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
return nil, nil, err return nil, nil, err
} }
proxies, err := q. if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
for _, order := range proxyListOrder(params) {
proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
}
func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
proxies, err := q.
Order(dbent.Desc(proxy.FieldID)). Order(dbent.Desc(proxy.FieldID)).
All(ctx) All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Get account counts result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
if err != nil {
return nil, nil, err
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(result, func(i, j int) bool {
if result[i].AccountCount == result[j].AccountCount {
return result[i].ID > result[j].ID
}
if sortOrder == pagination.SortOrderAsc {
return result[i].AccountCount < result[j].AccountCount
}
return result[i].AccountCount > result[j].AccountCount
})
return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
counts, err := r.GetAccountCountsForProxies(ctx) counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
proxyOut := proxyEntityToService(proxies[i]) proxyOut := proxyEntityToService(proxies[i])
...@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa ...@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
}) })
} }
return result, paginationResultFromTotal(int64(total), params), nil return result, paginationResultFromTotal(total, params), nil
}
func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = proxy.FieldName
case "protocol":
field = proxy.FieldProtocol
case "status":
field = proxy.FieldStatus
case "created_at":
field = proxy.FieldCreatedAt
default:
field = proxy.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
......
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "account_count",
SortOrder: "desc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 2)
s.Require().Equal(p1.ID, proxies[0].ID)
s.Require().Equal(int64(2), proxies[0].AccountCount)
s.Require().Equal(p2.ID, proxies[1].ID)
}
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -9,6 +10,8 @@ import ( ...@@ -9,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type redeemCodeRepository struct { type redeemCodeRepository struct {
...@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
WithUser(). WithUser().
WithGroup(). WithGroup().
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(redeemcode.FieldID)). for _, order := range redeemCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "type":
field = redeemcode.FieldType
case "value":
field = redeemcode.FieldValue
case "status":
field = redeemcode.FieldStatus
case "used_at":
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
case "code":
field = redeemcode.FieldCode
default:
field = redeemcode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID). up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code). SetCode(code.Code).
......
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "value",
SortOrder: "asc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 2)
s.Require().Equal("VALUE-10", codes[0].Code)
s.Require().Equal("VALUE-20", codes[1].Code)
}
...@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh ...@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
...@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context ...@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset) listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
...@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context ...@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
return logs, paginationResultFromTotal(total, params), nil return logs, paginationResultFromTotal(total, params), nil
} }
func usageLogOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
var column string
switch sortBy {
case "model":
column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
case "created_at":
column = "created_at"
default:
column = "id"
}
if column == "id" {
return fmt.Sprintf("id %s", sortOrder)
}
return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
......
...@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) ...@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
"total_account_cost", "total_account_cost",
"avg_duration_ms", "avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT CONCAT\\(").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters) stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err) require.NoError(t, err)
......
//go:build integration
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
)
func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
first := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "z-model",
RequestedModel: "z-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now(),
}
_, err := s.repo.Create(s.ctx, first)
s.Require().NoError(err)
second := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "a-model",
RequestedModel: "a-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().Add(time.Second),
}
_, err = s.repo.Create(s.ctx, second)
s.Require().NoError(err)
logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "model",
SortOrder: "asc",
}, usagestats.UsageLogFilters{UserID: user.ID})
s.Require().NoError(err)
s.Require().Len(logs, 2)
s.Require().Equal("a-model", logs[0].RequestedModel)
s.Require().Equal("z-model", logs[1].RequestedModel)
}
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type userRepository struct { type userRepository struct {
...@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return nil, nil, err return nil, nil, err
} }
users, err := q. usersQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(dbuser.FieldID)). for _, order := range userListOrder(params) {
All(ctx) usersQuery = usersQuery.Order(order)
}
users, err := usersQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil return outUsers, paginationResultFromTotal(int64(total), params), nil
} }
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
defaultField := true
switch sortBy {
case "email":
field = dbuser.FieldEmail
defaultField = false
case "username":
field = dbuser.FieldUsername
defaultField = false
case "role":
field = dbuser.FieldRole
defaultField = false
case "balance":
field = dbuser.FieldBalance
defaultField = false
case "concurrency":
field = dbuser.FieldConcurrency
defaultField = false
case "status":
field = dbuser.FieldStatus
defaultField = false
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
default:
field = dbuser.FieldID
}
if sortOrder == pagination.SortOrderAsc {
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters // filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) { func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 { if len(attrs) == 0 {
......
//go:build integration
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "email",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal("a-first@example.com", users[0].Email)
s.Require().Equal("z-last@example.com", users[1].Email)
}
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
first := s.mustCreateUser(&service.User{Email: "first@example.com"})
second := s.mustCreateUser(&service.User{Email: "second@example.com"})
users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal(second.ID, users[0].ID)
s.Require().Equal(first.ID, users[1].ID)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
...@@ -493,6 +493,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -493,6 +493,8 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
service.SettingKeyOpsMonitoringEnabled: "false", service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true", service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
...@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) {
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
"purchase_subscription_url": "", "purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50, 100],
"min_claude_code_version": "", "min_claude_code_version": "",
"max_claude_code_version": "", "max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
......
...@@ -21,13 +21,13 @@ import ( ...@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
...@@ -35,7 +35,7 @@ type AdminService interface { ...@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -55,7 +55,7 @@ type AdminService interface { ...@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -77,8 +77,8 @@ type AdminService interface { ...@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
...@@ -93,7 +93,7 @@ type AdminService interface { ...@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
...@@ -485,8 +485,8 @@ func NewAdminService( ...@@ -485,8 +485,8 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int ...@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou ...@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, ...@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po ...@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
......
...@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor ...@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil return nil
} }
func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
}
svc := &adminServiceImpl{groupRepo: repo}
_, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 3,
PageSize: 25,
SortBy: "account_count",
SortOrder: "ASC",
}, repo.listWithFiltersParams)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
...@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
...@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, groups) require.Empty(t, groups)
require.Equal(t, int64(0), total) require.Equal(t, int64(0), total)
...@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(42), total) require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
......
...@@ -15,9 +15,11 @@ type userRepoStubForListUsers struct { ...@@ -15,9 +15,11 @@ type userRepoStubForListUsers struct {
userRepoStub userRepoStub
users []User users []User
err error err error
listWithFiltersParams pagination.PaginationParams
} }
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
s.listWithFiltersParams = params
if s.err != nil { if s.err != nil {
return nil, nil, s.err return nil, nil, s.err
} }
...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo, userGroupRateRepo: rateRepo,
} }
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(2), total) require.Equal(t, int64(2), total)
require.Len(t, users, 2) require.Len(t, users, 2)
...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11]) require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22]) require.Equal(t, 2.2, users[1].GroupRates[22])
} }
func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 1, Email: "a@example.com"}},
}
svc := &adminServiceImpl{userRepo: userRepo}
_, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 2,
PageSize: 50,
SortBy: "email",
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { ...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) { ...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(7), total) require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol) require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch) require.Equal(t, "p1", repo.listWithFiltersSearch)
...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { ...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(9), total) require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { ...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{redeemCodeRepo: repo} svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(3), total) require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus) require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch) require.Equal(t, "ABC", repo.listWithFiltersSearch)
......
...@@ -4,6 +4,7 @@ import "time" ...@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct { type APIKeyAuthSnapshot struct {
Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
...@@ -65,6 +66,7 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -65,6 +66,7 @@ type APIKeyAuthGroupSnapshot struct {
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -13,6 +13,8 @@ import ( ...@@ -13,6 +13,8 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 3
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
l1TTL time.Duration l1TTL time.Duration
...@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn ...@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if entry.Snapshot == nil { if entry.Snapshot == nil {
return nil, false, nil return nil, false, nil
} }
if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
} }
...@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
return nil return nil
} }
snapshot := &APIKeyAuthSnapshot{ snapshot := &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
UserID: apiKey.UserID, UserID: apiKey.UserID,
GroupID: apiKey.GroupID, GroupID: apiKey.GroupID,
...@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes: apiKey.Group.SupportedModelScopes, SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel, DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
} }
} }
return snapshot return snapshot
...@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes, SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel, DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
} }
} }
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)
......
...@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{ Snapshot: &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: 1, APIKeyID: 1,
UserID: 2, UserID: 2,
GroupID: &groupID, GroupID: &groupID,
...@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting) require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
} }
func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) {
svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{})
groupID := int64(9)
apiKey := &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) {
cache := &authCacheStub{}
var repoCalls int32
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&repoCalls, 1)
groupID := int64(9)
return &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
Hydrated: true,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
},
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
},
},
}, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k-legacy")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls))
require.NotNil(t, apiKey.Group)
require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{} cache := &authCacheStub{}
repo := &authRepoStub{ repo := &authRepoStub{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment