Unverified Commit 14b155c6 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #7 from NepetaLemon/refactor/ports-pattern

refactor(backend): 引入端口接口模式
parents 7fd94ab7 e99b344b
...@@ -57,32 +57,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -57,32 +57,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageService := service.NewUsageService(usageLogRepository, userRepository) usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db) redeemCodeRepository := repository.NewRedeemCodeRepository(db)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository) billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
adminService := service.NewAdminService(repositories, billingCacheService) accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository) dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository) oAuthService := service.NewOAuthService(proxyRepository)
rateLimitService := service.NewRateLimitService(repositories, configConfig) rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
accountUsageService := service.NewAccountUsageService(repositories, oAuthService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService)
accountTestService := service.NewAccountTestService(repositories, oAuthService) accountTestService := service.NewAccountTestService(accountRepository, oAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService) oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
...@@ -98,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -98,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
} }
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client) identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client) concurrencyService := service.NewConcurrencyService(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
...@@ -132,6 +121,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -132,6 +121,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
Concurrency: concurrencyService, Concurrency: concurrencyService,
Identity: identityService, Identity: identityService,
} }
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories) engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services) v := provideCleanup(db, client, services)
......
...@@ -4,15 +4,15 @@ import ( ...@@ -4,15 +4,15 @@ import (
"strconv" "strconv"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service" "sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// toResponsePagination converts repository.PaginationResult to response.PaginationResult // toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult { func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil { if p == nil {
return nil return nil
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone" "sub2api/internal/pkg/timezone"
"sub2api/internal/repository" "sub2api/internal/repository"
...@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{ filters := repository.UsageLogFilters{
UserID: userID, UserID: userID,
ApiKeyID: apiKeyID, ApiKeyID: apiKeyID,
......
...@@ -4,8 +4,8 @@ import ( ...@@ -4,8 +4,8 @@ import (
"strconv" "strconv"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service" "sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) { ...@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil { if err != nil {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone" "sub2api/internal/pkg/timezone"
"sub2api/internal/repository" "sub2api/internal/repository"
...@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id apiKeyID = id
} }
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog var records []model.UsageLog
var result *repository.PaginationResult var result *pagination.PaginationResult
var err error var err error
if apiKeyID > 0 { if apiKeyID > 0 {
...@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { ...@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.InternalError(c, "Failed to verify API key ownership") response.InternalError(c, "Failed to verify API key ownership")
return return
......
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}
...@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in ...@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
}) })
} }
// PaginationResult 分页结果(与repository.PaginationResult兼容) // PaginationResult 分页结果(与pagination.PaginationResult兼容)
type PaginationResult struct { type PaginationResult struct {
Total int64 Total int64
Page int Page int
......
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error { ...@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) { func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) { func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account var accounts []model.Account
var total int64 var total int64
...@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati ...@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
pages++ pages++
} }
return accounts, &PaginationResult{ return accounts, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { ...@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
} }
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
...@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param ...@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
pages++ pages++
} }
return keys, &PaginationResult{ return keys, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e ...@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err return count > 0, err
} }
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
...@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
pages++ pages++
} }
return keys, &PaginationResult{ return keys, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error { ...@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
} }
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) { func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) { func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []model.Group
var total int64 var total int64
...@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination ...@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
pages++ pages++
} }
return groups, &PaginationResult{ return groups, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { ...@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
} }
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) { func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) { func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []model.Proxy
var total int64 var total int64
...@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination ...@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
pages++ pages++
} }
return proxies, &PaginationResult{ return proxies, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { ...@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
} }
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) { func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query // ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) { func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
var total int64 var total int64
...@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin ...@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
pages++ pages++
} }
return codes, &PaginationResult{ return codes, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -12,44 +12,3 @@ type Repositories struct { ...@@ -12,44 +12,3 @@ type Repositories struct {
Setting *SettingRepository Setting *SettingRepository
UserSubscription *UserSubscriptionRepository UserSubscription *UserSubscriptionRepository
} }
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}
...@@ -3,7 +3,9 @@ package repository ...@@ -3,7 +3,9 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "sub2api/internal/pkg/timezone"
"sub2api/internal/pkg/usagestats"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag ...@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
return &log, nil return &log, nil
} }
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p ...@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil }, nil
} }
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
...@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID ...@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
...@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe ...@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
...@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco ...@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
...@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { ...@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
} }
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) { func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today() today := timezone.Today()
var stats struct { var stats struct {
...@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
return nil, err return nil, err
} }
return &AccountStats{ return &usagestats.AccountStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
...@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
} }
// GetAccountWindowStats 获取账号时间窗口内的统计 // GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) { func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct { var stats struct {
Requests int64 `gorm:"column:requests"` Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"` Tokens int64 `gorm:"column:tokens"`
...@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return nil, err return nil, err
} }
return &AccountStats{ return &usagestats.AccountStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
...@@ -780,7 +775,7 @@ type UsageLogFilters struct { ...@@ -780,7 +775,7 @@ type UsageLogFilters struct {
} }
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -816,7 +811,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat ...@@ -816,7 +811,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error { ...@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
} }
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) { func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) { func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User var users []model.User
var total int64 var total int64
...@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP ...@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
pages++ pages++
} }
return users, &PaginationResult{ return users, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -127,4 +128,3 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group ...@@ -127,4 +128,3 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use ...@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
// ListByGroupID 获取分组的所有订阅(分页) // ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) { func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
...@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
pages++ pages++
} }
return subs, &PaginationResult{ return subs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
...@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) { func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
...@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination ...@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
pages++ pages++
} }
return subs, &PaginationResult{ return subs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
......
package repository package repository
import ( import (
"sub2api/internal/service/ports"
"github.com/google/wire" "github.com/google/wire"
) )
...@@ -16,4 +18,15 @@ var ProviderSet = wire.NewSet( ...@@ -16,4 +18,15 @@ var ProviderSet = wire.NewSet(
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"), wire.Struct(new(Repositories), "*"),
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
) )
...@@ -5,7 +5,8 @@ import ( ...@@ -5,7 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -41,12 +42,12 @@ type UpdateAccountRequest struct { ...@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
// AccountService 账号管理服务 // AccountService 账号管理服务
type AccountService struct { type AccountService struct {
accountRepo *repository.AccountRepository accountRepo ports.AccountRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
} }
// NewAccountService 创建账号服务实例 // NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService { func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
return &AccountService{ return &AccountService{
accountRepo: accountRepo, accountRepo: accountRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
...@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, ...@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
} }
// List 获取账号列表 // List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) { func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params) accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err) return nil, nil, fmt.Errorf("list accounts: %w", err)
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
"time" "time"
"sub2api/internal/pkg/claude" "sub2api/internal/pkg/claude"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -37,15 +37,15 @@ type TestEvent struct { ...@@ -37,15 +37,15 @@ type TestEvent struct {
// AccountTestService handles account testing operations // AccountTestService handles account testing operations
type AccountTestService struct { type AccountTestService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
oauthService *OAuthService oauthService *OAuthService
httpClient *http.Client httpClient *http.Client
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService { func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
repos: repos, accountRepo: accountRepo,
oauthService: oauthService, oauthService: oauthService,
httpClient: &http.Client{ httpClient: &http.Client{
Timeout: 60 * time.Second, Timeout: 60 * time.Second,
...@@ -105,7 +105,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -105,7 +105,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
ctx := c.Request.Context() ctx := c.Request.Context()
// Get account // Get account
account, err := s.repos.Account.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, "Account not found") return s.sendErrorAndEnd(c, "Account not found")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment