Commit c520de11 authored by qingyuzhang's avatar qingyuzhang
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api into qingyu/fix-smooth-sidebar-collapse

# Conflicts:
#	frontend/src/components/layout/AppSidebar.vue
parents 07d2add6 97f14b7a
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, g := range groups {
indexByID[g.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
// g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "sort_order",
SortOrder: "desc",
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, group := range groups {
indexByID[group.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
...@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams) ...@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
Pages: pages, Pages: pages,
} }
} }
func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
if len(items) == 0 {
return []T{}
}
offset := params.Offset()
if offset >= len(items) {
return []T{}
}
limit := params.Limit()
end := offset + limit
if end > len(items) {
end = len(items)
}
return items[offset:end]
}
...@@ -2,12 +2,15 @@ package repository ...@@ -2,12 +2,15 @@ package repository
import ( import (
"context" "context"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type promoCodeRepository struct { type promoCodeRepository struct {
...@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina ...@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(promocode.FieldID)). for _, order := range promoCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina ...@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "bonus_amount":
field = promocode.FieldBonusAmount
case "status":
field = promocode.FieldStatus
case "expires_at":
field = promocode.FieldExpiresAt
case "created_at":
field = promocode.FieldCreatedAt
case "code":
field = promocode.FieldCode
default:
field = promocode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error { func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create(). created, err := client.PromoCodeUsage.Create().
......
...@@ -3,12 +3,16 @@ package repository ...@@ -3,12 +3,16 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
) )
type sqlQuerier interface { type sqlQuerier interface {
...@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
proxies, err := q. proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(proxy.FieldID)). for _, order := range proxyListOrder(params) {
All(ctx) proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa ...@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
return nil, nil, err return nil, nil, err
} }
proxies, err := q. if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
for _, order := range proxyListOrder(params) {
proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
}
func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
proxies, err := q.
Order(dbent.Desc(proxy.FieldID)). Order(dbent.Desc(proxy.FieldID)).
All(ctx) All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Get account counts result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
if err != nil {
return nil, nil, err
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(result, func(i, j int) bool {
if result[i].AccountCount == result[j].AccountCount {
return result[i].ID > result[j].ID
}
if sortOrder == pagination.SortOrderAsc {
return result[i].AccountCount < result[j].AccountCount
}
return result[i].AccountCount > result[j].AccountCount
})
return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
counts, err := r.GetAccountCountsForProxies(ctx) counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
proxyOut := proxyEntityToService(proxies[i]) proxyOut := proxyEntityToService(proxies[i])
...@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa ...@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
}) })
} }
return result, paginationResultFromTotal(int64(total), params), nil return result, paginationResultFromTotal(total, params), nil
}
func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = proxy.FieldName
case "protocol":
field = proxy.FieldProtocol
case "status":
field = proxy.FieldStatus
case "created_at":
field = proxy.FieldCreatedAt
default:
field = proxy.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
......
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "account_count",
SortOrder: "desc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 2)
s.Require().Equal(p1.ID, proxies[0].ID)
s.Require().Equal(int64(2), proxies[0].AccountCount)
s.Require().Equal(p2.ID, proxies[1].ID)
}
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -9,6 +10,8 @@ import ( ...@@ -9,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type redeemCodeRepository struct { type redeemCodeRepository struct {
...@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
WithUser(). WithUser().
WithGroup(). WithGroup().
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(redeemcode.FieldID)). for _, order := range redeemCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "type":
field = redeemcode.FieldType
case "value":
field = redeemcode.FieldValue
case "status":
field = redeemcode.FieldStatus
case "used_at":
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
case "code":
field = redeemcode.FieldCode
default:
field = redeemcode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID). up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code). SetCode(code.Code).
......
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "value",
SortOrder: "asc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 2)
s.Require().Equal("VALUE-10", codes[0].Code)
s.Require().Equal("VALUE-20", codes[1].Code)
}
...@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh ...@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
...@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context ...@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset) listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
...@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context ...@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
return logs, paginationResultFromTotal(total, params), nil return logs, paginationResultFromTotal(total, params), nil
} }
func usageLogOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
var column string
switch sortBy {
case "model":
column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
case "created_at":
column = "created_at"
default:
column = "id"
}
if column == "id" {
return fmt.Sprintf("id %s", sortOrder)
}
return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
......
...@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) ...@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
"total_account_cost", "total_account_cost",
"avg_duration_ms", "avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT CONCAT\\(").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters) stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err) require.NoError(t, err)
......
//go:build integration
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
)
func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
first := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "z-model",
RequestedModel: "z-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now(),
}
_, err := s.repo.Create(s.ctx, first)
s.Require().NoError(err)
second := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "a-model",
RequestedModel: "a-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().Add(time.Second),
}
_, err = s.repo.Create(s.ctx, second)
s.Require().NoError(err)
logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "model",
SortOrder: "asc",
}, usagestats.UsageLogFilters{UserID: user.ID})
s.Require().NoError(err)
s.Require().Len(logs, 2)
s.Require().Equal("a-model", logs[0].RequestedModel)
s.Require().Equal("z-model", logs[1].RequestedModel)
}
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type userRepository struct { type userRepository struct {
...@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return nil, nil, err return nil, nil, err
} }
users, err := q. usersQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(dbuser.FieldID)). for _, order := range userListOrder(params) {
All(ctx) usersQuery = usersQuery.Order(order)
}
users, err := usersQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil return outUsers, paginationResultFromTotal(int64(total), params), nil
} }
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
defaultField := true
switch sortBy {
case "email":
field = dbuser.FieldEmail
defaultField = false
case "username":
field = dbuser.FieldUsername
defaultField = false
case "role":
field = dbuser.FieldRole
defaultField = false
case "balance":
field = dbuser.FieldBalance
defaultField = false
case "concurrency":
field = dbuser.FieldConcurrency
defaultField = false
case "status":
field = dbuser.FieldStatus
defaultField = false
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
default:
field = dbuser.FieldID
}
if sortOrder == pagination.SortOrderAsc {
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters // filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) { func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 { if len(attrs) == 0 {
......
//go:build integration
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "email",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal("a-first@example.com", users[0].Email)
s.Require().Equal("z-last@example.com", users[1].Email)
}
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
first := s.mustCreateUser(&service.User{Email: "first@example.com"})
second := s.mustCreateUser(&service.User{Email: "second@example.com"})
users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal(second.ID, users[0].ID)
s.Require().Equal(first.ID, users[1].ID)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
...@@ -491,8 +491,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -491,8 +491,10 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
service.SettingKeyOpsMonitoringEnabled: "false", service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true", service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
...@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) {
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
"purchase_subscription_url": "", "purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50, 100],
"min_claude_code_version": "", "min_claude_code_version": "",
"max_claude_code_version": "", "max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
...@@ -583,6 +587,24 @@ func TestAPIContracts(t *testing.T) { ...@@ -583,6 +587,24 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false, "enable_cch_signing": false,
"enable_fingerprint_unification": true, "enable_fingerprint_unification": true,
"enable_metadata_passthrough": false, "enable_metadata_passthrough": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [], "custom_menu_items": [],
"custom_endpoints": [] "custom_endpoints": []
} }
...@@ -696,7 +718,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -696,7 +718,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
......
...@@ -111,4 +111,5 @@ func registerRoutes( ...@@ -111,4 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
} }
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterPaymentRoutes registers all payment-related routes:
// user-facing endpoints, webhook endpoints, and admin endpoints.
func RegisterPaymentRoutes(
v1 *gin.RouterGroup,
paymentHandler *handler.PaymentHandler,
webhookHandler *handler.PaymentWebhookHandler,
adminPaymentHandler *admin.PaymentHandler,
jwtAuth middleware.JWTAuthMiddleware,
adminAuth middleware.AdminAuthMiddleware,
settingService *service.SettingService,
) {
// --- User-facing payment endpoints (authenticated) ---
authenticated := v1.Group("/payment")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
authenticated.GET("/plans", paymentHandler.GetPlans)
authenticated.GET("/channels", paymentHandler.GetChannels)
authenticated.GET("/limits", paymentHandler.GetLimits)
orders := authenticated.Group("/orders")
{
orders.POST("", paymentHandler.CreateOrder)
orders.POST("/verify", paymentHandler.VerifyOrder)
orders.GET("/my", paymentHandler.GetMyOrders)
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
}
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
}
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
// EasyPay sends GET callbacks with query params
webhook.GET("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
webhook.POST("/stripe", webhookHandler.StripeWebhook)
}
// --- Admin payment endpoints (admin auth) ---
adminGroup := v1.Group("/admin/payment")
adminGroup.Use(gin.HandlerFunc(adminAuth))
{
// Dashboard
adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
// Config
adminGroup.GET("/config", adminPaymentHandler.GetConfig)
adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
// Orders
adminOrders := adminGroup.Group("/orders")
{
adminOrders.GET("", adminPaymentHandler.ListOrders)
adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
}
// Subscription Plans
plans := adminGroup.Group("/plans")
{
plans.GET("", adminPaymentHandler.ListPlans)
plans.POST("", adminPaymentHandler.CreatePlan)
plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
}
// Provider Instances
providers := adminGroup.Group("/providers")
{
providers.GET("", adminPaymentHandler.ListProviders)
providers.POST("", adminPaymentHandler.CreateProvider)
providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
}
}
}
...@@ -21,13 +21,13 @@ import ( ...@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
...@@ -35,7 +35,7 @@ type AdminService interface { ...@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -55,7 +55,7 @@ type AdminService interface { ...@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -77,8 +77,8 @@ type AdminService interface { ...@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
...@@ -93,7 +93,7 @@ type AdminService interface { ...@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
...@@ -485,8 +485,8 @@ func NewAdminService( ...@@ -485,8 +485,8 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int ...@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou ...@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, ...@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po ...@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
......
...@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor ...@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil return nil
} }
func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
}
svc := &adminServiceImpl{groupRepo: repo}
_, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 3,
PageSize: 25,
SortBy: "account_count",
SortOrder: "ASC",
}, repo.listWithFiltersParams)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
...@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
...@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, groups) require.Empty(t, groups)
require.Equal(t, int64(0), total) require.Equal(t, int64(0), total)
...@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { ...@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(42), total) require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
......
...@@ -13,11 +13,13 @@ import ( ...@@ -13,11 +13,13 @@ import (
type userRepoStubForListUsers struct { type userRepoStubForListUsers struct {
userRepoStub userRepoStub
users []User users []User
err error err error
listWithFiltersParams pagination.PaginationParams
} }
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
s.listWithFiltersParams = params
if s.err != nil { if s.err != nil {
return nil, nil, s.err return nil, nil, s.err
} }
...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo, userGroupRateRepo: rateRepo,
} }
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(2), total) require.Equal(t, int64(2), total)
require.Len(t, users, 2) require.Len(t, users, 2)
...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { ...@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11]) require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22]) require.Equal(t, 2.2, users[1].GroupRates[22])
} }
func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 1, Email: "a@example.com"}},
}
svc := &adminServiceImpl{userRepo: userRepo}
_, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 2,
PageSize: 50,
SortBy: "email",
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { ...@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) { ...@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(7), total) require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol) require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch) require.Equal(t, "p1", repo.listWithFiltersSearch)
...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { ...@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(9), total) require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { ...@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{redeemCodeRepo: repo} svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(3), total) require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus) require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch) require.Equal(t, "ABC", repo.listWithFiltersSearch)
......
...@@ -4,6 +4,7 @@ import "time" ...@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct { type APIKeyAuthSnapshot struct {
Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
...@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment