Commit 5f8e60a1 authored by IanShaw027's avatar IanShaw027
Browse files

feat(table): 表格排序与搜索改为后端处理

parent 66e15a54
...@@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) {
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode")) privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
sortBy := c.DefaultQuery("sort_by", "name")
sortOrder := c.DefaultQuery("sort_order", "asc")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
...@@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -2029,7 +2031,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -2029,7 +2031,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -31,6 +31,33 @@ type stubAdminService struct { ...@@ -31,6 +31,33 @@ type stubAdminService struct {
platform string platform string
groupIDs []int64 groupIDs []int64
} }
lastListAccounts struct {
platform string
accountType string
status string
search string
groupID int64
privacyMode string
sortBy string
sortOrder string
calls int
}
lastListProxies struct {
protocol string
status string
search string
sortBy string
sortOrder string
calls int
}
lastListRedeemCodes struct {
codeType string
status string
search string
sortBy string
sortOrder string
calls int
}
mu sync.Mutex mu sync.Mutex
} }
...@@ -99,7 +126,7 @@ func newStubAdminService() *stubAdminService { ...@@ -99,7 +126,7 @@ func newStubAdminService() *stubAdminService {
} }
} }
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) { func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil return s.users, int64(len(s.users)), nil
} }
...@@ -132,7 +159,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -132,7 +159,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil return &user, nil
} }
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) { func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil return s.apiKeys, int64(len(s.apiKeys)), nil
} }
...@@ -140,7 +167,7 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -140,7 +167,7 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil return map[string]any{"user_id": userID}, nil
} }
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) { func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil return s.groups, int64(len(s.groups)), nil
} }
...@@ -187,7 +214,16 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int ...@@ -187,7 +214,16 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil return nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType
s.lastListAccounts.status = status
s.lastListAccounts.search = search
s.lastListAccounts.groupID = groupID
s.lastListAccounts.privacyMode = privacyMode
s.lastListAccounts.sortBy = sortBy
s.lastListAccounts.sortOrder = sortOrder
s.lastListAccounts.calls++
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
...@@ -261,7 +297,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc ...@@ -261,7 +297,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc
return s.checkMixedErr return s.checkMixedErr
} }
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.Proxy, int64, error) {
s.lastListProxies.protocol = protocol
s.lastListProxies.status = status
s.lastListProxies.search = search
s.lastListProxies.sortBy = sortBy
s.lastListProxies.sortOrder = sortOrder
s.lastListProxies.calls++
search = strings.TrimSpace(strings.ToLower(search)) search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies)) filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies { for _, proxy := range s.proxies {
...@@ -283,7 +325,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, ...@@ -283,7 +325,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int,
return filtered, int64(len(filtered)), nil return filtered, int64(len(filtered)), nil
} }
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil return s.proxyCounts, int64(len(s.proxyCounts)), nil
} }
...@@ -384,7 +426,13 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se ...@@ -384,7 +426,13 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
}, nil }, nil
} }
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]service.RedeemCode, int64, error) {
s.lastListRedeemCodes.codeType = codeType
s.lastListRedeemCodes.status = status
s.lastListRedeemCodes.search = search
s.lastListRedeemCodes.sortBy = sortBy
s.lastListRedeemCodes.sortOrder = sortOrder
s.lastListRedeemCodes.calls++
return s.redeems, int64(len(s.redeems)), nil return s.redeems, int64(len(s.redeems)), nil
} }
......
...@@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) { ...@@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status")) status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 200 { if len(search) > 200 {
search = search[:200] search = search[:200]
} }
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: sortBy,
SortOrder: sortOrder,
} }
items, paginationResult, err := h.announcementService.List( items, paginationResult, err := h.announcementService.List(
...@@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) { ...@@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "email"),
SortOrder: c.DefaultQuery("sort_order", "asc"),
} }
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 { if len(search) > 200 {
......
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type announcementRepoCapture struct {
service.AnnouncementRepository
listParams pagination.PaginationParams
}
func (r *announcementRepoCapture) List(ctx context.Context, params pagination.PaginationParams, filters service.AnnouncementListFilters) ([]service.Announcement, *pagination.PaginationResult, error) {
r.listParams = params
return []service.Announcement{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (r *announcementRepoCapture) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
return &service.Announcement{
ID: id,
Title: "announcement",
Content: "content",
Status: service.AnnouncementStatusActive,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
}
type announcementUserRepoCapture struct {
service.UserRepository
listParams pagination.PaginationParams
}
func (r *announcementUserRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
r.listParams = params
return []service.User{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
type announcementReadRepoCapture struct {
service.AnnouncementReadRepository
}
func (r *announcementReadRepoCapture) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
return map[int64]time.Time{}, nil
}
type announcementUserSubRepoCapture struct {
service.UserSubscriptionRepository
}
func newAnnouncementSortTestRouter(announcementRepo *announcementRepoCapture, userRepo *announcementUserRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewAnnouncementService(
announcementRepo,
&announcementReadRepoCapture{},
userRepo,
&announcementUserSubRepoCapture{},
)
handler := NewAnnouncementHandler(svc)
router := gin.New()
router.GET("/admin/announcements", handler.List)
router.GET("/admin/announcements/:id/read-status", handler.ListReadStatus)
return router
}
func TestAdminAnnouncementListSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements?sort_by=title&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "title", announcementRepo.listParams.SortBy)
require.Equal(t, "ASC", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementListSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", announcementRepo.listParams.SortBy)
require.Equal(t, "desc", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status?sort_by=balance&sort_order=DESC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "balance", userRepo.listParams.SortBy)
require.Equal(t, "DESC", userRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "email", userRepo.listParams.SortBy)
require.Equal(t, "asc", userRepo.listParams.SortOrder)
}
...@@ -245,7 +245,12 @@ func (h *ChannelHandler) List(c *gin.Context) { ...@@ -245,7 +245,12 @@ func (h *ChannelHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}, status, search)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -160,6 +160,8 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -160,6 +160,8 @@ func (h *GroupHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
isExclusiveStr := c.Query("is_exclusive") isExclusiveStr := c.Query("is_exclusive")
sortBy := c.DefaultQuery("sort_by", "sort_order")
sortOrder := c.DefaultQuery("sort_order", "asc")
var isExclusive *bool var isExclusive *bool
if isExclusiveStr != "" { if isExclusiveStr != "" {
...@@ -167,7 +169,7 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -167,7 +169,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val isExclusive = &val
} }
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) { ...@@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
} }
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search) codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
......
...@@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) { ...@@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol") protocol := c.Query("protocol")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -165,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -165,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: userID, UserID: userID,
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
...@@ -339,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) { ...@@ -339,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
} }
// Limit to 30 results // Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}) users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}, "email", "asc")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -15,11 +15,14 @@ import ( ...@@ -15,11 +15,14 @@ import (
type adminUsageRepoCapture struct { type adminUsageRepoCapture struct {
service.UsageLogRepository service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters listFilters usagestats.UsageLogFilters
statsParams pagination.PaginationParams
statsFilters usagestats.UsageLogFilters statsFilters usagestats.UsageLogFilters
} }
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{ return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0, Total: 0,
......
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestAdminUsageListSortParams(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestAdminUsageListSortDefaults(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}
...@@ -91,12 +91,14 @@ func (h *UserHandler) List(c *gin.Context) { ...@@ -91,12 +91,14 @@ func (h *UserHandler) List(c *gin.Context) {
GroupName: strings.TrimSpace(c.Query("group_name")), GroupName: strings.TrimSpace(c.Query("group_name")),
Attributes: parseAttributeFilters(c), Attributes: parseAttributeFilters(c),
} }
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if raw, ok := c.GetQuery("include_subscriptions"); ok { if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true) includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions filters.IncludeSubscriptions = &includeSubscriptions
} }
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -290,8 +292,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) { ...@@ -290,8 +292,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) { ...@@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
// Parse filter parameters // Parse filter parameters
var filters service.APIKeyListFilters var filters service.APIKeyListFilters
......
...@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
......
...@@ -16,10 +16,12 @@ import ( ...@@ -16,10 +16,12 @@ import (
type userUsageRepoCapture struct { type userUsageRepoCapture struct {
service.UsageLogRepository service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters listFilters usagestats.UsageLogFilters
} }
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{ return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0, Total: 0,
......
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestUserUsageListSortParams(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestUserUsageListSortDefaults(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}
// Package pagination provides types and helpers for paginated responses. // Package pagination provides types and helpers for paginated responses.
package pagination package pagination
import "strings"
const (
SortOrderAsc = "asc"
SortOrderDesc = "desc"
)
// PaginationParams 分页参数 // PaginationParams 分页参数
type PaginationParams struct { type PaginationParams struct {
Page int Page int
PageSize int PageSize int
SortBy string
SortOrder string
} }
// PaginationResult 分页结果 // PaginationResult 分页结果
...@@ -18,8 +27,9 @@ type PaginationResult struct { ...@@ -18,8 +27,9 @@ type PaginationResult struct {
// DefaultPagination 默认分页参数 // DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams { func DefaultPagination() PaginationParams {
return PaginationParams{ return PaginationParams{
Page: 1, Page: 1,
PageSize: 20, PageSize: 20,
SortOrder: SortOrderDesc,
} }
} }
...@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int { ...@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int {
if p.PageSize < 1 { if p.PageSize < 1 {
return 20 return 20
} }
if p.PageSize > 100 { if p.PageSize > 1000 {
return 100 return 1000
} }
return p.PageSize return p.PageSize
} }
// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder.
func NormalizeSortOrder(order string, defaultOrder string) string {
switch strings.ToLower(strings.TrimSpace(defaultOrder)) {
case SortOrderAsc:
defaultOrder = SortOrderAsc
default:
defaultOrder = SortOrderDesc
}
switch strings.ToLower(strings.TrimSpace(order)) {
case SortOrderAsc:
return SortOrderAsc
case SortOrderDesc:
return SortOrderDesc
default:
return defaultOrder
}
}
// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback.
func (p PaginationParams) NormalizedSortOrder(defaultOrder string) string {
return NormalizeSortOrder(p.SortOrder, defaultOrder)
}
package pagination
import "testing"
func TestNormalizeSortOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultOrder string
want string
}{
{name: "asc", input: "asc", defaultOrder: "desc", want: "asc"},
{name: "uppercase asc", input: "ASC", defaultOrder: "desc", want: "asc"},
{name: "desc", input: "desc", defaultOrder: "asc", want: "desc"},
{name: "trim spaces", input: " desc ", defaultOrder: "asc", want: "desc"},
{name: "invalid falls back", input: "sideways", defaultOrder: "asc", want: "asc"},
{name: "empty falls back", input: "", defaultOrder: "desc", want: "desc"},
{name: "invalid default falls back to desc", input: "", defaultOrder: "wat", want: "desc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeSortOrder(tt.input, tt.defaultOrder); got != tt.want {
t.Fatalf("NormalizeSortOrder(%q, %q) = %q, want %q", tt.input, tt.defaultOrder, got, tt.want)
}
})
}
}
func TestPaginationParamsNormalizedSortOrder(t *testing.T) {
t.Parallel()
params := PaginationParams{SortOrder: "ASC"}
if got := params.NormalizedSortOrder("desc"); got != "asc" {
t.Fatalf("NormalizedSortOrder = %q, want asc", got)
}
params = PaginationParams{SortOrder: "bad"}
if got := params.NormalizedSortOrder("asc"); got != "asc" {
t.Fatalf("NormalizedSortOrder invalid fallback = %q, want asc", got)
}
}
func TestPaginationParamsLimit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pageSize int
want int
}{
{name: "non-positive falls back to default", pageSize: 0, want: 20},
{name: "negative falls back to default", pageSize: -1, want: 20},
{name: "normal value keeps", pageSize: 50, want: 50},
{name: "max value keeps", pageSize: 1000, want: 1000},
{name: "beyond max clamps to 1000", pageSize: 1500, want: 1000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := PaginationParams{PageSize: tt.pageSize}
if got := p.Limit(); got != tt.want {
t.Fatalf("Limit() for PageSize=%d = %d, want %d", tt.pageSize, got, tt.want)
}
})
}
}
...@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
case service.StatusActive: case service.StatusActive:
q = q.Where( q = q.Where(
dbaccount.StatusEQ(status), dbaccount.StatusEQ(status),
dbaccount.SchedulableEQ(true),
dbaccount.Or( dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()), dbaccount.RateLimitResetAtLTE(time.Now()),
), ),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
) )
case "rate_limited": case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.RateLimitResetAtGT(time.Now()),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
case "temp_unschedulable": case "temp_unschedulable":
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { q = q.Where(
col := s.C("temp_unschedulable_until") dbaccount.StatusEQ(service.StatusActive),
s.Where(entsql.And( dbpredicate.Account(func(s *entsql.Selector) {
entsql.Not(entsql.IsNull(col)), col := s.C("temp_unschedulable_until")
entsql.GT(col, entsql.Expr("NOW()")), s.Where(entsql.And(
)) entsql.Not(entsql.IsNull(col)),
})) entsql.GT(col, entsql.Expr("NOW()")),
))
}),
)
case "unschedulable":
q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(false),
dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()),
),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
default: default:
q = q.Where(dbaccount.StatusEQ(status)) q = q.Where(dbaccount.StatusEQ(status))
} }
...@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err return nil, nil, err
} }
accounts, err := q. accountsQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(dbaccount.FieldID)). for _, order := range accountListOrder(params) {
All(ctx) accountsQuery = accountsQuery.Order(order)
}
accounts, err := accountsQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return outAccounts, paginationResultFromTotal(int64(total), params), nil return outAccounts, paginationResultFromTotal(int64(total), params), nil
} }
func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
field := dbaccount.FieldName
defaultOrder := true
switch sortBy {
case "", "name":
field = dbaccount.FieldName
case "id":
field = dbaccount.FieldID
defaultOrder = false
case "status":
field = dbaccount.FieldStatus
defaultOrder = false
case "schedulable":
field = dbaccount.FieldSchedulable
defaultOrder = false
case "priority":
field = dbaccount.FieldPriority
defaultOrder = false
case "rate_multiplier":
field = dbaccount.FieldRateMultiplier
defaultOrder = false
case "last_used_at":
field = dbaccount.FieldLastUsedAt
defaultOrder = false
case "expires_at":
field = dbaccount.FieldExpiresAt
defaultOrder = false
case "created_at":
field = dbaccount.FieldCreatedAt
defaultOrder = false
}
if sortOrder == pagination.SortOrderDesc {
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)}
}
if defaultOrder {
return []func(*entsql.Selector){dbent.Asc(dbaccount.FieldName), dbent.Asc(dbaccount.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)}
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive, status: service.StatusActive,
......
...@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}, },
}, },
{ {
name: "filter_by_status_active_excludes_rate_limited", name: "filter_by_status_active_excludes_runtime_blocked_accounts",
setup: func(client *dbent.Client) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive}) mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive})
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
...@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() {
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background()) Exec(context.Background())
s.Require().NoError(err) s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
}, },
status: service.StatusActive, status: service.StatusActive,
wantCount: 1, wantCount: 1,
...@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Equal("active-normal", accounts[0].Name) s.Require().Equal("active-normal", accounts[0].Name)
}, },
}, },
{
name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true})
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err := client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err = client.Account.UpdateOneID(rateLimited.ID).
SetSchedulable(false).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetSchedulable(false).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-unsched", accounts[0].Name)
},
},
{
name: "filter_by_status_rate_limited_excludes_temp_unschedulable",
setup: func(client *dbent.Client) {
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err := client.Account.UpdateOneID(rateLimited.ID).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetRateLimitResetAt(time.Now().Add(20 * time.Minute)).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "rate_limited",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-rate-limited", accounts[0].Name)
},
},
{
name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable",
setup: func(client *dbent.Client) {
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true})
err := client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
},
status: "temp_unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-temp-unsched", accounts[0].Name)
},
},
{ {
name: "filter_by_search", name: "filter_by_search",
setup: func(client *dbent.Client) { setup: func(client *dbent.Client) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment