Commit e99b344b authored by Forest's avatar Forest
Browse files

refactor(backend): 引入端口接口模式

parent 7fd94ab7
...@@ -8,7 +8,8 @@ import ( ...@@ -8,7 +8,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -49,8 +50,8 @@ type RedeemCodeResponse struct { ...@@ -49,8 +50,8 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务 // RedeemService 兑换码服务
type RedeemService struct { type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository redeemRepo ports.RedeemCodeRepository
userRepo *repository.UserRepository userRepo ports.UserRepository
subscriptionService *SubscriptionService subscriptionService *SubscriptionService
rdb *redis.Client rdb *redis.Client
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
...@@ -58,8 +59,8 @@ type RedeemService struct { ...@@ -58,8 +59,8 @@ type RedeemService struct {
// NewRedeemService 创建兑换码服务实例 // NewRedeemService 创建兑换码服务实例
func NewRedeemService( func NewRedeemService(
redeemRepo *repository.RedeemCodeRepository, redeemRepo ports.RedeemCodeRepository,
userRepo *repository.UserRepository, userRepo ports.UserRepository,
subscriptionService *SubscriptionService, subscriptionService *SubscriptionService,
rdb *redis.Client, rdb *redis.Client,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
...@@ -337,7 +338,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede ...@@ -337,7 +338,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
} }
// List 获取兑换码列表(管理员功能) // List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) { func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params) codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err) return nil, nil, fmt.Errorf("list redeem codes: %w", err)
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -18,12 +18,12 @@ var ( ...@@ -18,12 +18,12 @@ var (
// SettingService 系统设置服务 // SettingService 系统设置服务
type SettingService struct { type SettingService struct {
settingRepo *repository.SettingRepository settingRepo ports.SettingRepository
cfg *config.Config cfg *config.Config
} }
// NewSettingService 创建系统设置服务实例 // NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService { func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{ return &SettingService{
settingRepo: settingRepo, settingRepo: settingRepo,
cfg: cfg, cfg: cfg,
......
...@@ -7,7 +7,8 @@ import ( ...@@ -7,7 +7,8 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
) )
var ( var (
...@@ -23,14 +24,16 @@ var ( ...@@ -23,14 +24,16 @@ var (
// SubscriptionService 订阅服务 // SubscriptionService 订阅服务
type SubscriptionService struct { type SubscriptionService struct {
repos *repository.Repositories groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
} }
// NewSubscriptionService 创建订阅服务 // NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService { func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{ return &SubscriptionService{
repos: repos, groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
} }
} }
...@@ -47,7 +50,7 @@ type AssignSubscriptionInput struct { ...@@ -47,7 +50,7 @@ type AssignSubscriptionInput struct {
// AssignSubscription 分配订阅给用户(不允许重复分配) // AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) { func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型 // 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID) group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("group not found: %w", err) return nil, fmt.Errorf("group not found: %w", err)
} }
...@@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass ...@@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
} }
// 检查是否已存在订阅 // 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID) exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass ...@@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// 如果没有订阅:创建新订阅 // 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) { func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型 // 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID) group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil { if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err) return nil, false, fmt.Errorf("group not found: %w", err)
} }
...@@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
// 查询是否已有订阅 // 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID) existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil { if err != nil {
// 不存在记录是正常情况,其他错误需要返回 // 不存在记录是正常情况,其他错误需要返回
existingSub = nil existingSub = nil
...@@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
// 更新过期时间 // 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil { if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err) return nil, false, fmt.Errorf("extend subscription: %w", err)
} }
// 如果订阅已过期或被暂停,恢复为active状态 // 如果订阅已过期或被暂停,恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive { if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil { if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err) return nil, false, fmt.Errorf("update subscription status: %w", err)
} }
} }
...@@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newNotes += "\n" newNotes += "\n"
} }
newNotes += input.Notes newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程 // 备注更新失败不影响主流程
} }
} }
...@@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
// 返回更新后的订阅 // 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID) sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期 return sub, true, err // true 表示是续期
} }
...@@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass ...@@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
sub.AssignedBy = &input.AssignedBy sub.AssignedBy = &input.AssignedBy
} }
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil { if err := s.userSubRepo.Create(ctx, sub); err != nil {
return nil, err return nil, err
} }
// 重新获取完整订阅信息(包含关联) // 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID) return s.userSubRepo.GetByID(ctx, sub.ID)
} }
// BulkAssignSubscriptionInput 批量分配订阅输入 // BulkAssignSubscriptionInput 批量分配订阅输入
...@@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input ...@@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
// RevokeSubscription 撤销订阅 // RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error { func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存 // 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil { if err != nil {
return err return err
} }
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil { if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
return err return err
} }
...@@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti ...@@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
// ExtendSubscription 延长订阅 // ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) { func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil { if err != nil {
return nil, ErrSubscriptionNotFound return nil, ErrSubscriptionNotFound
} }
// 计算新的过期时间 // 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err return nil, err
} }
// 如果订阅已过期,恢复为active状态 // 如果订阅已过期,恢复为active状态
if sub.Status == model.SubscriptionStatusExpired { if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil { if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err return nil, err
} }
} }
...@@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti ...@@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}() }()
} }
return s.repos.UserSubscription.GetByID(ctx, subscriptionID) return s.userSubRepo.GetByID(ctx, subscriptionID)
} }
// GetByID 根据ID获取订阅 // GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id) return s.userSubRepo.GetByID(ctx, id)
} }
// GetActiveSubscription 获取用户对特定分组的有效订阅 // GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID) sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil { if err != nil {
return nil, ErrSubscriptionNotFound return nil, ErrSubscriptionNotFound
} }
...@@ -331,24 +334,24 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, ...@@ -331,24 +334,24 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
// ListUserSubscriptions 获取用户的所有订阅 // ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID) return s.userSubRepo.ListByUserID(ctx, userID)
} }
// ListActiveUserSubscriptions 获取用户的所有有效订阅 // ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID) return s.userSubRepo.ListActiveByUserID(ctx, userID)
} }
// ListGroupSubscriptions 获取分组的所有订阅 // ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) { func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params) return s.userSubRepo.ListByGroupID(ctx, groupID, params)
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) { func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status) return s.userSubRepo.List(ctx, params, userID, groupID, status)
} }
// CheckAndActivateWindow 检查并激活窗口(首次使用时) // CheckAndActivateWindow 检查并激活窗口(首次使用时)
...@@ -358,7 +361,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m ...@@ -358,7 +361,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
} }
now := time.Now() now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now) return s.userSubRepo.ActivateWindows(ctx, sub.ID, now)
} }
// CheckAndResetWindows 检查并重置过期的窗口 // CheckAndResetWindows 检查并重置过期的窗口
...@@ -367,7 +370,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod ...@@ -367,7 +370,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 日窗口重置(24小时) // 日窗口重置(24小时)
if sub.NeedsDailyReset() { if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil { if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, now); err != nil {
return err return err
} }
sub.DailyWindowStart = &now sub.DailyWindowStart = &now
...@@ -376,7 +379,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod ...@@ -376,7 +379,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 周窗口重置(7天) // 周窗口重置(7天)
if sub.NeedsWeeklyReset() { if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil { if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
return err return err
} }
sub.WeeklyWindowStart = &now sub.WeeklyWindowStart = &now
...@@ -385,7 +388,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod ...@@ -385,7 +388,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 月窗口重置(30天) // 月窗口重置(30天)
if sub.NeedsMonthlyReset() { if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil { if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
return err return err
} }
sub.MonthlyWindowStart = &now sub.MonthlyWindowStart = &now
...@@ -411,7 +414,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U ...@@ -411,7 +414,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
// RecordUsage 记录使用量到订阅 // RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error { func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD) return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
} }
// SubscriptionProgress 订阅进度 // SubscriptionProgress 订阅进度
...@@ -438,14 +441,14 @@ type UsageWindowProgress struct { ...@@ -438,14 +441,14 @@ type UsageWindowProgress struct {
// GetSubscriptionProgress 获取订阅使用进度 // GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) { func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil { if err != nil {
return nil, ErrSubscriptionNotFound return nil, ErrSubscriptionNotFound
} }
group := sub.Group group := sub.Group
if group == nil { if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID) group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -535,7 +538,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc ...@@ -535,7 +538,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度 // GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) { func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID) subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -554,7 +557,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte ...@@ -554,7 +557,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用) // UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) { func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx) return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
} }
// ValidateSubscription 验证订阅是否有效 // ValidateSubscription 验证订阅是否有效
...@@ -567,7 +570,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod ...@@ -567,7 +570,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
} }
if sub.IsExpired() { if sub.IsExpired() {
// 更新状态 // 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired) _ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired return ErrSubscriptionExpired
} }
return nil return nil
......
...@@ -5,7 +5,8 @@ import ( ...@@ -5,7 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct { ...@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct {
// UsageStats 使用统计 // UsageStats 使用统计
type UsageStats struct { type UsageStats struct {
TotalRequests int64 `json:"total_requests"` TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"` TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"` TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"` TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"` TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"` TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"` AverageDurationMs float64 `json:"average_duration_ms"`
} }
// UsageService 使用统计服务 // UsageService 使用统计服务
type UsageService struct { type UsageService struct {
usageRepo *repository.UsageLogRepository usageRepo ports.UsageLogRepository
userRepo *repository.UserRepository userRepo ports.UserRepository
} }
// NewUsageService 创建使用统计服务实例 // NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService { func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService {
return &UsageService{ return &UsageService{
usageRepo: usageRepo, usageRepo: usageRepo,
userRepo: userRepo, userRepo: userRepo,
...@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, ...@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
} }
// ListByUser 获取用户的使用日志列表 // ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params) logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err) return nil, nil, fmt.Errorf("list usage logs: %w", err)
...@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo ...@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo
} }
// ListByApiKey 获取API Key的使用日志列表 // ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params) logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err) return nil, nil, fmt.Errorf("list usage logs: %w", err)
...@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params ...@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
} }
// ListByAccount 获取账号的使用日志列表 // ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params) logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err) return nil, nil, fmt.Errorf("list usage logs: %w", err)
...@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int ...@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
} }
result = append(result, map[string]interface{}{ result = append(result, map[string]interface{}{
"date": date, "date": date,
"total_requests": stats.TotalRequests, "total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens, "total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens, "total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens, "total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens, "total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost, "total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost, "total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs, "average_duration_ms": stats.AverageDurationMs,
}) })
} }
......
...@@ -6,16 +6,17 @@ import ( ...@@ -6,16 +6,17 @@ import (
"fmt" "fmt"
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
) )
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect") ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions") ErrInsufficientPerms = errors.New("insufficient permissions")
) )
// UpdateProfileRequest 更新用户资料请求 // UpdateProfileRequest 更新用户资料请求
...@@ -32,12 +33,12 @@ type ChangePasswordRequest struct { ...@@ -32,12 +33,12 @@ type ChangePasswordRequest struct {
// UserService 用户服务 // UserService 用户服务
type UserService struct { type UserService struct {
userRepo *repository.UserRepository userRepo ports.UserRepository
cfg *config.Config cfg *config.Config
} }
// NewUserService 创建用户服务实例 // NewUserService 创建用户服务实例
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService { func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
return &UserService{ return &UserService{
userRepo: userRepo, userRepo: userRepo,
cfg: cfg, cfg: cfg,
...@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error ...@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
} }
// List 获取用户列表(管理员功能) // List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) { func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params) users, pagination, err := s.userRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err) return nil, nil, fmt.Errorf("list users: %w", err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment