Commit eeaff85e authored by Forest's avatar Forest
Browse files

refactor: 自定义业务错误

parent f51ad2e1
...@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct { ...@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UsageLogRepository repo *usageLogRepository
} }
func (s *UsageLogRepoSuite) SetupTest() { func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db) s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
} }
func TestUsageLogRepoSuite(t *testing.T) { func TestUsageLogRepoSuite(t *testing.T) {
......
...@@ -2,56 +2,61 @@ package repository ...@@ -2,56 +2,61 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UserRepository struct { type userRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUserRepository(db *gorm.DB) *UserRepository { func NewUserRepository(db *gorm.DB) service.UserRepository {
return &UserRepository{db: db} return &userRepository{db: db}
} }
func (r *UserRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
} }
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User var users []model.User
var total int64 var total int64
...@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil }, nil
} }
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
...@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo ...@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在 return service.ErrInsufficientBalance
} }
return nil return nil
} }
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
...@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, ...@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
...@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group ...@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -18,13 +19,13 @@ type UserRepoSuite struct { ...@@ -18,13 +19,13 @@ type UserRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserRepository repo *userRepository
} }
func (s *UserRepoSuite) SetupTest() { func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserRepository(s.db) s.repo = NewUserRepository(s.db).(*userRepository)
} }
func TestUserRepoSuite(t *testing.T) { func TestUserRepoSuite(t *testing.T) {
...@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { ...@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrInsufficientBalance)
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
...@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
err = s.repo.DeductBalance(s.ctx, user1.ID, 999) err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance") s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error") s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID) got5, err := s.repo.GetByID(s.ctx, user1.ID)
......
...@@ -6,27 +6,29 @@ import ( ...@@ -6,27 +6,29 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库 // UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库 // NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 // Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(sub).Error
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 // GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
...@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo ...@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&sub, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 // GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 // GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
...@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con ...@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// Update 更新订阅 // Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error return r.db.WithContext(ctx).Save(sub).Error
} }
// Delete 删除订阅 // Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 // ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
...@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in ...@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
} }
// ListActiveByUserID 获取用户的所有有效订阅 // ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
...@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use ...@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
// ListByGroupID 获取分组的所有订阅(分页) // ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
...@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
...@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
} }
// IncrementUsage 增加使用量 // IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 ...@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
} }
// ResetDailyUsage 重置日使用量 // ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int ...@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
} }
// ResetWeeklyUsage 重置周使用量 // ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in ...@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
} }
// ResetMonthlyUsage 重置月使用量 // ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i ...@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
} }
// ActivateWindows 激活所有窗口(首次使用时) // ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int ...@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
} }
// UpdateStatus 更新订阅状态 // UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, ...@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
} }
// ExtendExpiry 延长订阅过期时间 // ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, ...@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
} }
// UpdateNotes 更新订阅备注 // UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, ...@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
} }
// ListExpired 获取所有已过期但状态仍为active的订阅 // ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
...@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U ...@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
} }
// BatchUpdateExpiredStatus 批量更新过期订阅状态 // BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
...@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex ...@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
...@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex ...@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
} }
// CountByGroupID 获取分组的订阅数量 // CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
...@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID ...@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
} }
// CountActiveByGroupID 获取分组的有效订阅数量 // CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
...@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g ...@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
} }
// DeleteByGroupID 删除分组相关的所有订阅记录 // DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
...@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct { ...@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserSubscriptionRepository repo *userSubscriptionRepository
} }
func (s *UserSubscriptionRepoSuite) SetupTest() { func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db) s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
} }
func TestUserSubscriptionRepoSuite(t *testing.T) { func TestUserSubscriptionRepoSuite(t *testing.T) {
......
package repository package repository
import ( import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
) )
...@@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet( ...@@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet(
NewClaudeOAuthClient, NewClaudeOAuthClient,
NewHTTPUpstream, NewHTTPUpstream,
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(service.UserRepository), new(*UserRepository)),
wire.Bind(new(service.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(service.GroupRepository), new(*GroupRepository)),
wire.Bind(new(service.AccountRepository), new(*AccountRepository)),
wire.Bind(new(service.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(service.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(service.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(service.SettingRepository), new(*SettingRepository)),
wire.Bind(new(service.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
) )
...@@ -2,17 +2,16 @@ package service ...@@ -2,17 +2,16 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrAccountNotFound = errors.New("account not found") ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
) )
type AccountRepository interface { type AccountRepository interface {
...@@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
...@@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
return account, nil return account, nil
...@@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode ...@@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
...@@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
for _, groupID := range *req.GroupIDs { for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
...@@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { ...@@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在 // 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id) _, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
...@@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { ...@@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
...@@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error { ...@@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) { func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err) return "", fmt.Errorf("get account: %w", err)
} }
...@@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string ...@@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
...@@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在 affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存)
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要)
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
if err != nil { if err != nil {
return err return err
} }
......
...@@ -9,20 +9,20 @@ import ( ...@@ -9,20 +9,20 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
) )
var ( var (
ErrApiKeyNotFound = errors.New("api key not found") ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists") ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
...@@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
...@@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
if req.GroupID != nil { if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
...@@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
return apiKey, nil return apiKey, nil
...@@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey ...@@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库 // 这里可以添加Redis缓存逻辑,暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
...@@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey ...@@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
...@@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
...@@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err) return fmt.Errorf("get api key: %w", err)
} }
...@@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api ...@@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// 检查API Key状态 // 检查API Key状态
if !apiKey.IsActive() { if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active") return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
} }
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID) user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err) return nil, nil, fmt.Errorf("get user: %w", err)
} }
...@@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
...@@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户的所有有效订阅 // 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID) activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err) return nil, fmt.Errorf("list active subscriptions: %w", err)
} }
......
...@@ -8,22 +8,22 @@ import ( ...@@ -8,22 +8,22 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var ( var (
ErrInvalidCredentials = errors.New("invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = errors.New("user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = errors.New("email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = errors.New("invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = errors.New("token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
) )
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
...@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// 查找用户 // 查找用户
user, err := s.userRepo.GetByEmail(ctx, email) user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials return "", nil, ErrInvalidCredentials
} }
// 记录数据库错误但不暴露给用户 // 记录数据库错误但不暴露给用户
...@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( ...@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 获取最新的用户信息 // 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID) user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", ErrInvalidToken return "", ErrInvalidToken
} }
log.Printf("[Auth] Database error refreshing token: %v", err) log.Printf("[Auth] Database error refreshing token: %v", err)
......
...@@ -2,11 +2,11 @@ package service ...@@ -2,11 +2,11 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
......
...@@ -4,21 +4,21 @@ import ( ...@@ -4,21 +4,21 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
var ( var (
ErrEmailNotConfigured = errors.New("email service not configured") ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code") ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code") ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code") ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
) )
// EmailCache defines cache operations for email service // EmailCache defines cache operations for email service
......
...@@ -2,17 +2,16 @@ package service ...@@ -2,17 +2,16 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrGroupNotFound = errors.New("group not found") ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
ErrGroupExists = errors.New("group name already exists") ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
) )
type GroupRepository interface { type GroupRepository interface {
...@@ -20,6 +19,7 @@ type GroupRepository interface { ...@@ -20,6 +19,7 @@ type GroupRepository interface {
GetByID(ctx context.Context, id int64) (*model.Group, error) GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
...@@ -29,8 +29,6 @@ type GroupRepository interface { ...@@ -29,8 +29,6 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
} }
// CreateGroupRequest 创建分组请求 // CreateGroupRequest 创建分组请求
...@@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod ...@@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
return group, nil return group, nil
...@@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { ...@@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
...@@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { ...@@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在 // 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id) _, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err) return fmt.Errorf("get group: %w", err)
} }
...@@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { ...@@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) { func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
......
...@@ -2,16 +2,15 @@ package service ...@@ -2,16 +2,15 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrProxyNotFound = errors.New("proxy not found") ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
) )
type ProxyRepository interface { type ProxyRepository interface {
...@@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod ...@@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
} }
return proxy, nil return proxy, nil
...@@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { ...@@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
} }
...@@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error { ...@@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在 // 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id) _, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err) return fmt.Errorf("get proxy: %w", err)
} }
...@@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error { ...@@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err) return fmt.Errorf("get proxy: %w", err)
} }
...@@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { ...@@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) { func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err) return "", fmt.Errorf("get proxy: %w", err)
} }
......
...@@ -9,19 +9,18 @@ import ( ...@@ -9,19 +9,18 @@ import (
"strings" "strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
) )
var ( var (
ErrRedeemCodeNotFound = errors.New("redeem code not found") ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used") ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code") ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
ErrInsufficientBalance = errors.New("insufficient balance") ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later") ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
) )
const ( const (
...@@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 查找兑换码 // 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code) redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrRedeemCodeNotFound) {
s.incrementRedeemErrorCount(ctx, userID) s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound return nil, ErrRedeemCodeNotFound
} }
...@@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 验证兑换码类型的前置条件 // 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil { if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id") return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
} }
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
_ = user // 使用变量避免未使用错误 _ = user // 使用变量避免未使用错误
...@@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 【关键】先标记兑换码为已使用,确保并发安全 // 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性 // 利用数据库乐观锁(WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil { if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
// 兑换码已被其他请求使用
return nil, ErrRedeemCodeUsed return nil, ErrRedeemCodeUsed
} }
return nil, fmt.Errorf("mark code as used: %w", err) return nil, fmt.Errorf("mark code as used: %w", err)
...@@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id) code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err) return nil, fmt.Errorf("get redeem code: %w", err)
} }
return code, nil return code, nil
...@@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod ...@@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code) redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err) return nil, fmt.Errorf("get redeem code: %w", err)
} }
return redeemCode, nil return redeemCode, nil
...@@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error { ...@@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在 // 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id) code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err) return fmt.Errorf("get redeem code: %w", err)
} }
// 不允许删除已使用的兑换码 // 不允许删除已使用的兑换码
if code.IsUsed() { if code.IsUsed() {
return errors.New("cannot delete used redeem code") return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
} }
if err := s.redeemRepo.Delete(ctx, id); err != nil { if err := s.redeemRepo.Delete(ctx, id); err != nil {
......
...@@ -9,13 +9,13 @@ import ( ...@@ -9,13 +9,13 @@ import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
) )
var ( var (
ErrRegistrationDisabled = errors.New("registration is currently disabled") ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
) )
type SettingRepository interface { type SettingRepository interface {
...@@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 已有设置,不需要初始化 // 已有设置,不需要初始化
return nil return nil
} }
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, ErrSettingNotFound) {
return fmt.Errorf("check existing settings: %w", err) return fmt.Errorf("check existing settings: %w", err)
} }
...@@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error ...@@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", false, nil return "", false, nil
} }
return "", false, err return "", false, err
...@@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st ...@@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串 return "", nil // 未配置,返回空字符串
} }
return "", err // 数据库错误 return "", err // 数据库错误
......
...@@ -2,24 +2,24 @@ package service ...@@ -2,24 +2,24 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
var ( var (
ErrSubscriptionNotFound = errors.New("subscription not found") ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired") ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended") ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group") ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
) )
// SubscriptionService 订阅服务 // SubscriptionService 订阅服务
......
...@@ -2,14 +2,15 @@ package service ...@@ -2,14 +2,15 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
) )
var ( var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed") ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured") ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
) )
// TurnstileVerifier 验证 Turnstile token 的接口 // TurnstileVerifier 验证 Turnstile token 的接口
......
...@@ -2,18 +2,17 @@ package service ...@@ -2,18 +2,17 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"gorm.io/gorm"
) )
var ( var (
ErrUsageLogNotFound = errors.New("usage log not found") ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found")
) )
// CreateUsageLogRequest 创建使用日志请求 // CreateUsageLogRequest 创建使用日志请求
...@@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* ...@@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 验证用户存在 // 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID) _, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
...@@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* ...@@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id) log, err := s.usageRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err) return nil, fmt.Errorf("get usage log: %w", err)
} }
return log, nil return log, nil
......
...@@ -2,19 +2,18 @@ package service ...@@ -2,19 +2,18 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
) )
type UserRepository interface { type UserRepository interface {
...@@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService { ...@@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) { func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
return user, nil return user, nil
...@@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User ...@@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) { func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
...@@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat ...@@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err) return fmt.Errorf("get user: %w", err)
} }
...@@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan ...@@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) { func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
return user, nil return user, nil
...@@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl ...@@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error { func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err) return fmt.Errorf("get user: %w", err)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment