Commit eeaff85e authored by Forest's avatar Forest
Browse files

refactor: 自定义业务错误

parent f51ad2e1
......@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UsageLogRepository
repo *usageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db)
s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
}
func TestUsageLogRepoSuite(t *testing.T) {
......
......@@ -2,56 +2,61 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type UserRepository struct {
type userRepository struct {
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
}
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
var total int64
......@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil
}
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
......@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在
return service.ErrInsufficientBalance
}
return nil
}
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
......@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
......@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC").
First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
......@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
......@@ -18,13 +19,13 @@ type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserRepository
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db)
s.repo = NewUserRepository(s.db).(*userRepository)
}
func TestUserRepoSuite(t *testing.T) {
......@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
......@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
......
......@@ -6,27 +6,29 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct {
type userSubscriptionRepository struct {
db *gorm.DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db}
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
}
// Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
err := r.db.WithContext(ctx).Create(sub).Error
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
// GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("User").
......@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
Preload("AssignedByUser").
First(&sub, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
......@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error
}
// Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
}
// ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
......@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
// ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
......@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
......@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
// List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
......@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
}
// IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}
// ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
// ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
// ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
// UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
}
// ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
}
// UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
......@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
......@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
......@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
......@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
}
// CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID).
......@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
}
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
......@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error
}
......@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserSubscriptionRepository
repo *userSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db)
s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
......
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
)
......@@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet(
NewClaudeOAuthClient,
NewHTTPUpstream,
NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(service.UserRepository), new(*UserRepository)),
wire.Bind(new(service.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(service.GroupRepository), new(*GroupRepository)),
wire.Bind(new(service.AccountRepository), new(*AccountRepository)),
wire.Bind(new(service.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(service.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(service.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(service.SettingRepository), new(*SettingRepository)),
wire.Bind(new(service.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
)
......@@ -2,17 +2,16 @@ package service
import (
"context"
"errors"
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrAccountNotFound = errors.New("account not found")
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
)
type AccountRepository interface {
......@@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
......@@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
......@@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
......@@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
......@@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
......@@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
......@@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err)
}
......@@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
......
......@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
// AdminService interface defines admin management operations
......@@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存)
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要)
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
}
......
......@@ -9,20 +9,20 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
......@@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
......@@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
......@@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
......@@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
......@@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
......@@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
......@@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err)
}
......@@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active")
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err)
}
......@@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
......@@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
......
......@@ -8,22 +8,22 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable")
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
)
// JWTClaims JWT载荷数据
......@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials
}
// 记录数据库错误但不暴露给用户
......@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrUserNotFound) {
return "", ErrInvalidToken
}
log.Printf("[Auth] Database error refreshing token: %v", err)
......
......@@ -2,11 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
......@@ -14,7 +14,7 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
......
......@@ -4,21 +4,21 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
ErrEmailNotConfigured = errors.New("email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
)
// EmailCache defines cache operations for email service
......
......@@ -2,17 +2,16 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrGroupNotFound = errors.New("group not found")
ErrGroupExists = errors.New("group name already exists")
ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
)
type GroupRepository interface {
......@@ -20,6 +19,7 @@ type GroupRepository interface {
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
......@@ -29,8 +29,6 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
}
// CreateGroupRequest 创建分组请求
......@@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
return group, nil
......@@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
......@@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err)
}
......@@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
......
......@@ -2,16 +2,15 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrProxyNotFound = errors.New("proxy not found")
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
)
type ProxyRepository interface {
......@@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
return proxy, nil
......@@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
......@@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
......@@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
......@@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err)
}
......
......@@ -9,19 +9,18 @@ import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrRedeemCodeNotFound = errors.New("redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
ErrInsufficientBalance = errors.New("insufficient balance")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
)
const (
......@@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrRedeemCodeNotFound) {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound
}
......@@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id")
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
......@@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 兑换码已被其他请求使用
if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
return nil, ErrRedeemCodeUsed
}
return nil, fmt.Errorf("mark code as used: %w", err)
......@@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return code, nil
......@@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return redeemCode, nil
......@@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err)
}
// 不允许删除已使用的兑换码
if code.IsUsed() {
return errors.New("cannot delete used redeem code")
return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
}
if err := s.redeemRepo.Delete(ctx, id); err != nil {
......
......@@ -9,13 +9,13 @@ import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
)
var (
ErrRegistrationDisabled = errors.New("registration is currently disabled")
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
)
type SettingRepository interface {
......@@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 已有设置,不需要初始化
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
if !errors.Is(err, ErrSettingNotFound) {
return fmt.Errorf("check existing settings: %w", err)
}
......@@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
}
return "", false, err
......@@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
}
return "", err // 数据库错误
......
......@@ -2,24 +2,24 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
)
// SubscriptionService 订阅服务
......
......@@ -2,14 +2,15 @@ package service
import (
"context"
"errors"
"fmt"
"log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
)
var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
)
// TurnstileVerifier 验证 Turnstile token 的接口
......
......@@ -2,18 +2,17 @@ package service
import (
"context"
"errors"
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"gorm.io/gorm"
)
var (
ErrUsageLogNotFound = errors.New("usage log not found")
ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found")
)
// CreateUsageLogRequest 创建使用日志请求
......@@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
......@@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err)
}
return log, nil
......
......@@ -2,19 +2,18 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
type UserRepository interface {
......@@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
......@@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
......@@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
......@@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
......@@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment