Commit bfcd9501 authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并 upstream/main 解决 PR #37 冲突

- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
parents 9780f0fd 12252c60
...@@ -4,10 +4,8 @@ import ( ...@@ -4,10 +4,8 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository { ...@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
}
return r.db.WithContext(ctx).Create(&models).Error
} }
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var code model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var redeemCode model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
} }
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { var codes []redeemCodeModel
var codes []model.RedeemCode
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{}) db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
// Apply filters
if codeType != "" { if codeType != "" {
db = db.Where("type = ?", codeType) db = db.Where("type = ?", codeType)
} }
...@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outCodes := make([]service.RedeemCode, 0, len(codes))
if int(total)%params.Limit() > 0 { for i := range codes {
pages++ outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
} }
return codes, &pagination.PaginationResult{ return outCodes, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusUsed, "status": service.StatusUsed,
"used_by": userID, "used_by": userID,
"used_at": now, "used_at": now,
}) })
...@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error ...@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10
} }
var codes []redeemCodeModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("used_by = ?", userID). Where("used_by = ?", userID).
Order("used_at DESC"). Order("used_at DESC").
Limit(limit). Limit(limit).
Find(&codes).Error Find(&codes).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return codes, nil
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
} }
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) { ...@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
// --- Create / CreateBatch / GetByID / GetByCode --- // --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() { func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{ code := &service.RedeemCode{
Code: "TEST-CREATE", Code: "TEST-CREATE",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Value: 100, Value: 100,
Status: model.StatusUnused, Status: service.StatusUnused,
} }
err := s.repo.Create(s.ctx, code) err := s.repo.Create(s.ctx, code)
...@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() { ...@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
} }
func (s *RedeemCodeRepoSuite) TestCreateBatch() { func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused}, {Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused}, {Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
} }
err := s.repo.CreateBatch(s.ctx, codes) err := s.repo.CreateBatch(s.ctx, codes)
...@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() { ...@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
} }
func (s *RedeemCodeRepoSuite) TestGetByCode() { func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE") got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
...@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() { ...@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
// --- Delete --- // --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() { func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID) err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() { ...@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() { func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() { ...@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type) s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status) s.Require().Equal(service.StatusUsed, codes[0].Status)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { ...@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP", Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
GroupID: &group.ID, GroupID: &group.ID,
}) })
...@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { ...@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update --- // --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() { func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10}) code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code.Value = 50 code.Value = 50
err := s.repo.Update(s.ctx, code) err := s.repo.Update(s.ctx, code)
...@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() { ...@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use --- // --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() { func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use") s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID) got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status) s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy) s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy) s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt) s.Require().NotNil(got.UsedAt)
} }
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time") s.Require().NoError(err, "Use first time")
...@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { ...@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
...@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { ...@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser --- // --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() { func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering // Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1", Code: "USER-1",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c1).Update("used_at", base) s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2", Code: "USER-2",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour)) s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
...@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() { ...@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP", Code: "WITH-GRP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
GroupID: &group.ID, GroupID: &group.ID,
}) })
...@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { ...@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM", Code: "DEF-LIM",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c).Update("used_at", time.Now()) s.db.Model(c).Update("used_at", time.Now())
...@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { ...@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test --- // --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() { func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()}, {Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, {Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
} }
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch") s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code") list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1) s.Require().Len(list, 1)
...@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser ...@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering // Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA") s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10) used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser") s.Require().NoError(err, "ListByUser")
......
...@@ -6,33 +6,27 @@ import ( ...@@ -6,33 +6,27 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层
type settingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) service.SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { var m settingModel
var setting model.Setting err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil) return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return settingModelToService(&m), nil
} }
// GetValue 获取设置值字符串
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
...@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e ...@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
return setting.Value, nil return setting.Value, nil
} }
// Set 设置值(存在则更新,不存在则创建)
func (r *settingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
...@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error { ...@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{ return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error }).Create(m).Error
} }
// GetMultiple 批量获取设置
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map ...@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
return result, nil return result, nil
} }
// SetMultiple 批量设置值
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
...@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string ...@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
if err := tx.Clauses(clause.OnConflict{ if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil { }).Create(m).Error; err != nil {
return err return err
} }
} }
...@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string ...@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
}) })
} }
// GetAll 获取所有设置
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro ...@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
return result, nil return result, nil
} }
// Delete 删除设置
func (r *settingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
} }
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
...@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
TokenCount int64 `gorm:"column:token_count"` TokenCount int64 `gorm:"column:token_count"`
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as request_count, COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
...@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error m := usageLogModelFromService(log)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUsageLogModelToService(log, m)
}
return err
} }
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
var log model.UsageLog var log usageLogModel
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil) return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return usageLogModelToService(&log), nil
} }
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p ...@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UserStats 用户使用统计 // UserStats 用户使用统计
...@@ -125,7 +109,7 @@ type UserStats struct { ...@@ -125,7 +109,7 @@ type UserStats struct {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
...@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
today := timezone.Today() today := timezone.Today()
// 总用户数 // 总用户数
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers) r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
// 今日新增用户数 // 今日新增用户数
r.db.WithContext(ctx).Model(&model.User{}). r.db.WithContext(ctx).Model(&userModel{}).
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.TodayNewUsers) Count(&stats.TodayNewUsers)
// 今日活跃用户数 (今日有请求的用户) // 今日活跃用户数 (今日有请求的用户)
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Distinct("user_id"). Distinct("user_id").
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.ActiveUsers) Count(&stats.ActiveUsers)
// 总 API Key 数 // 总 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys) r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
// 活跃 API Key 数 // 活跃 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 总账户数 // 总账户数
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts) r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
// 正常账户数 (schedulable=true, status=active) // 正常账户数 (schedulable=true, status=active)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Count(&stats.NormalAccounts) Count(&stats.NormalAccounts)
// 异常账户数 (status=error) // 异常账户数 (status=error)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ?", model.StatusError). Where("status = ?", service.StatusError).
Count(&stats.ErrorAccounts) Count(&stats.ErrorAccounts)
// 限流账户数 // 限流账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()). Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
Count(&stats.RateLimitAccounts) Count(&stats.RateLimitAccounts)
// 过载账户数 // 过载账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()). Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
Count(&stats.OverloadAccounts) Count(&stats.OverloadAccounts)
...@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
...@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
...@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
...@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
...@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today() today := timezone.Today()
// API Key 统计 // API Key 统计
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Count(&stats.TotalApiKeys) Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ? AND status = ?", userID, model.StatusActive). Where("user_id = ? AND status = ?", userID, service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 累计 Token 统计 // 累计 Token 统计
...@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
...@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user ...@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user ...@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 ...@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}) db := r.db.WithContext(ctx).Model(&usageLogModel{})
// Apply filters // Apply filters
if filters.UserID > 0 { if filters.UserID > 0 {
...@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UsageStats represents usage statistics // UsageStats represents usage statistics
...@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("user_id IN ?", userIDs). Where("user_id IN ?", userIDs).
Group("user_id"). Group("user_id").
...@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("user_id IN ? AND created_at >= ?", userIDs, today). Where("user_id IN ? AND created_at >= ?", userIDs, today).
Group("user_id"). Group("user_id").
...@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("api_key_id IN ?", apiKeyIDs). Where("api_key_id IN ?", apiKeyIDs).
Group("api_key_id"). Group("api_key_id").
...@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today). Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
Group("api_key_id"). Group("api_key_id").
...@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT ...@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost float64 `gorm:"column:actual_cost"` ActualCost float64 `gorm:"column:actual_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date, TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var avgDuration struct { var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"` AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms"). Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration) Scan(&avgDuration)
...@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Models: models, Models: models,
}, nil }, nil
} }
type usageLogModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
ApiKeyID int64 `gorm:"index;not null"`
AccountID int64 `gorm:"index;not null"`
RequestID string `gorm:"size:64"`
Model string `gorm:"size:100;index;not null"`
GroupID *int64 `gorm:"index"`
SubscriptionID *int64 `gorm:"index"`
InputTokens int `gorm:"default:0;not null"`
OutputTokens int `gorm:"default:0;not null"`
CacheCreationTokens int `gorm:"default:0;not null"`
CacheReadTokens int `gorm:"default:0;not null"`
CacheCreation5mTokens int `gorm:"default:0;not null"`
CacheCreation1hTokens int `gorm:"default:0;not null"`
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
BillingType int8 `gorm:"type:smallint;default:0;not null"`
Stream bool `gorm:"default:false;not null"`
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time `gorm:"index;not null"`
User *userModel `gorm:"foreignKey:UserID"`
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
}
func (usageLogModel) TableName() string { return "usage_logs" }
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
if m == nil {
return nil
}
return &service.UsageLog{
ID: m.ID,
UserID: m.UserID,
ApiKeyID: m.ApiKeyID,
AccountID: m.AccountID,
RequestID: m.RequestID,
Model: m.Model,
GroupID: m.GroupID,
SubscriptionID: m.SubscriptionID,
InputTokens: m.InputTokens,
OutputTokens: m.OutputTokens,
CacheCreationTokens: m.CacheCreationTokens,
CacheReadTokens: m.CacheReadTokens,
CacheCreation5mTokens: m.CacheCreation5mTokens,
CacheCreation1hTokens: m.CacheCreation1hTokens,
InputCost: m.InputCost,
OutputCost: m.OutputCost,
CacheCreationCost: m.CacheCreationCost,
CacheReadCost: m.CacheReadCost,
TotalCost: m.TotalCost,
ActualCost: m.ActualCost,
RateMultiplier: m.RateMultiplier,
BillingType: m.BillingType,
Stream: m.Stream,
DurationMs: m.DurationMs,
FirstTokenMs: m.FirstTokenMs,
CreatedAt: m.CreatedAt,
User: userModelToService(m.User),
ApiKey: apiKeyModelToService(m.ApiKey),
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
Subscription: userSubscriptionModelToService(m.Subscription),
}
}
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
out := make([]service.UsageLog, 0, len(models))
for i := range models {
if s := usageLogModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
if log == nil {
return nil
}
return &usageLogModel{
ID: log.ID,
UserID: log.UserID,
ApiKeyID: log.ApiKeyID,
AccountID: log.AccountID,
RequestID: log.RequestID,
Model: log.Model,
GroupID: log.GroupID,
SubscriptionID: log.SubscriptionID,
InputTokens: log.InputTokens,
OutputTokens: log.OutputTokens,
CacheCreationTokens: log.CacheCreationTokens,
CacheReadTokens: log.CacheReadTokens,
CacheCreation5mTokens: log.CacheCreation5mTokens,
CacheCreation1hTokens: log.CacheCreation1hTokens,
InputCost: log.InputCost,
OutputCost: log.OutputCost,
CacheCreationCost: log.CacheCreationCost,
CacheReadCost: log.CacheReadCost,
TotalCost: log.TotalCost,
ActualCost: log.ActualCost,
RateMultiplier: log.RateMultiplier,
BillingType: log.BillingType,
Stream: log.Stream,
DurationMs: log.DurationMs,
FirstTokenMs: log.FirstTokenMs,
CreatedAt: log.CreatedAt,
}
}
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
if log == nil || m == nil {
return
}
log.ID = m.ID
log.CreatedAt = m.CreatedAt
}
...@@ -7,10 +7,10 @@ import ( ...@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) { ...@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe ...@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
// --- Create / GetByID --- // --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() { ...@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
} }
func (s *UsageLogRepoSuite) TestGetByID() { func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { ...@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() { ...@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser --- // --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() { func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
...@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() { ...@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey --- // --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() { func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
...@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() { ...@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount --- // --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() { ...@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats --- // --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() { ...@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters --- // --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now() now := time.Now()
todayStart := timezone.Today() todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{ userToday := mustCreateUser(s.T(), s.db, &userModel{
Email: "today@example.com", Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)), CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now, UpdatedAt: now,
}) })
userOld := mustCreateUser(s.T(), s.db, &model.User{ userOld := mustCreateUser(s.T(), s.db, &userModel{
Email: "old@example.com", Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour), CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour), UpdatedAt: todayStart.Add(-24 * time.Hour),
}) })
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true}) accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300 d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{ logToday := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{ logOld := &service.UsageLog{
UserID: userOld.ID, UserID: userOld.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{ logPerf := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats --- // --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { ...@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats --- // --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { ...@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats --- // --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
...@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { ...@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats --- // --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
...@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { ...@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats --- // --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() { func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time { ...@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange --- // --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { ...@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange --- // --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { ...@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange --- // --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { ...@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange --- // --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{ log3 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats --- // --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
now := time.Now() now := time.Now()
windowStart := now.Add(-10 * time.Minute) windowStart := now.Add(-10 * time.Minute)
...@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { ...@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID --- // --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { ...@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
} }
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { ...@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats --- // --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters --- // --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
} }
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters --- // --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats --- // --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days // Create logs on different days
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base startTime := base
...@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { ...@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend --- // --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
...@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { ...@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend --- // --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
...@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { ...@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
} }
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
...@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { ...@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) --- // --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { ...@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { ...@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
......
...@@ -2,12 +2,13 @@ package repository ...@@ -2,12 +2,13 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository { ...@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db} return &userRepository{db: db}
} }
func (r *userRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Create(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Save(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
} }
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []model.User var users []userModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.User{}) db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters // Apply filters
if status != "" { if status != "" {
...@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Batch load subscriptions for all users (avoid N+1) // Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 { if len(users) > 0 {
userIDs := make([]int64, len(users)) userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users)) userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users { for i := range users {
userIDs[i] = users[i].ID userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i] u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
} }
// Query active subscriptions with groups in one query // Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive). Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil { Find(&subscriptions).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Associate subscriptions with users // Associate subscriptions with users
for i := range subscriptions { for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok { if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i]) user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
} }
} }
return outUsers, paginationResultFromTotal(total, params), nil
} }
pages := int(total) / params.Limit() outUsers := make([]service.User, 0, len(users))
if int(total)%params.Limit() > 0 { for i := range users {
pages++ outUsers = append(outUsers, *userModelToService(&users[i]))
} }
return users, &pagination.PaginationResult{ return outUsers, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil { if result.Error != nil {
...@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo ...@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
} }
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
} }
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
} }
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
...@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) { ...@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByEmail / Update / Delete --- // --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() { func (s *UserRepoSuite) TestCreate() {
user := &model.User{ user := &service.User{
Email: "create@test.com", Email: "create@test.com",
Username: "testuser", Username: "testuser",
Role: model.RoleUser, PasswordHash: "test-password-hash",
Status: model.StatusActive, Role: service.RoleUser,
Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, user) err := s.repo.Create(s.ctx, user)
...@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() { ...@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserRepoSuite) TestGetByEmail() { func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email) got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail") s.Require().NoError(err, "GetByEmail")
...@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() { ...@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
} }
func (s *UserRepoSuite) TestUpdate() { func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"}) user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated" user.Username = "updated"
err := s.repo.Update(s.ctx, user) err := s.repo.Update(s.ctx, user)
...@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() { ...@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
} }
func (s *UserRepoSuite) TestDelete() { func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID) err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() { ...@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() { func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() { ...@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
} }
func (s *UserRepoSuite) TestListWithFilters_Status() { func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive}) mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled}) mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status) s.Require().Equal(service.StatusActive, users[0].Status)
} }
func (s *UserRepoSuite) TestListWithFilters_Role() { func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser}) mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role) s.Require().Equal(service.RoleAdmin, users[0].Role)
} }
func (s *UserRepoSuite) TestListWithFilters_Search() { func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"}) mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"}) mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() { ...@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { ...@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { ...@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
} }
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
}) })
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour), ExpiresAt: time.Now().Add(-1 * time.Hour),
}) })
...@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { ...@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
} }
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
target := mustCreateUser(s.T(), s.db, &model.User{ target := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
...@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations --- // --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() { func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5) err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance") s.Require().NoError(err, "UpdateBalance")
...@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() { ...@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
} }
func (s *UserRepoSuite) TestUpdateBalance_Negative() { func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3) err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative") s.Require().NoError(err, "UpdateBalance with negative")
...@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() { ...@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
} }
func (s *UserRepoSuite) TestDeductBalance() { func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5) err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance") s.Require().NoError(err, "DeductBalance")
...@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() { ...@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
} }
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
...@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { ...@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10) err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount") s.Require().NoError(err, "DeductBalance exact amount")
...@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { ...@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
// --- Concurrency --- // --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() { func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3) err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency") s.Require().NoError(err, "UpdateConcurrency")
...@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() { ...@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
} }
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2) err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative") s.Require().NoError(err, "UpdateConcurrency negative")
...@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { ...@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail --- // --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() { func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com") exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail") s.Require().NoError(err, "ExistsByEmail")
...@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() { ...@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42) groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{ userA := mustCreateUser(s.T(), s.db, &userModel{
Email: "a1@example.com", Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7}, AllowedGroups: pq.Int64Array{groupID, 7},
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a2@example.com", Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7}, AllowedGroups: pq.Int64Array{7},
}) })
...@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { ...@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
} }
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "nomatch@test.com", Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3}, AllowedGroups: pq.Int64Array{1, 2, 3},
}) })
...@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { ...@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin --- // --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{ admin1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "admin1@example.com", Email: "admin1@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "admin2@example.com", Email: "admin2@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() { ...@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "user@example.com", Email: "user@example.com",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
}) })
_, err := s.repo.GetFirstAdmin(s.ctx) _, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { ...@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "disabled@example.com", Email: "disabled@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{ activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
Email: "active@example.com", Email: "active@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { ...@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{ user1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
user2 := mustCreateUser(s.T(), s.db, &model.User{ user2 := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
_ = mustCreateUser(s.T(), s.db, &model.User{ _ = mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
got, err := s.repo.GetByID(s.ctx, user1.ID) got, err := s.repo.GetByID(s.ctx, user1.ID)
...@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch") s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10} params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
......
...@@ -4,111 +4,113 @@ import ( ...@@ -4,111 +4,113 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库
type userSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists) return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
Preload("Group"). Preload("Group").
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// Update 更新订阅 func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return err
} }
// Delete 删除订阅
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ?", userID). Where("user_id = ?", userID).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListActiveByUserID 获取用户的所有有效订阅 func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()). userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListByGroupID 获取分组的所有订阅(分页) func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// List 获取所有订阅(分页,支持筛选) func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
if userID != nil { if userID != nil {
query = query.Where("user_id = ?", *userID) query = query.Where("user_id = ?", *userID)
} }
...@@ -170,156 +160,240 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -170,156 +160,240 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 { }
pages++
}
return subs, &pagination.PaginationResult{ func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
Total: total, var count int64
Page: params.Page, err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
PageSize: params.Limit(), Where("user_id = ? AND group_id = ?", userID, groupID).
Pages: pages, Count(&count).Error
}, nil return count > 0, err
} }
// IncrementUsage 增加使用量 func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD), "expires_at": newExpiresAt,
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD), "updated_at": time.Now(),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetDailyUsage 重置日使用量 func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_usage_usd": 0, "status": status,
"daily_window_start": newWindowStart, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetWeeklyUsage 重置周使用量 func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"weekly_usage_usd": 0, "notes": notes,
"weekly_window_start": newWindowStart, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetMonthlyUsage 重置月使用量 func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"monthly_usage_usd": 0, "daily_window_start": start,
"monthly_window_start": newWindowStart, "weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(), "updated_at": time.Now(),
}).Error }).Error
} }
// ActivateWindows 激活所有窗口(首次使用时) func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_window_start": activateTime, "daily_usage_usd": 0,
"weekly_window_start": activateTime, "daily_window_start": newWindowStart,
"monthly_window_start": activateTime, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// UpdateStatus 更新订阅状态 func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": status, "weekly_usage_usd": 0,
"updated_at": time.Now(), "weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error }).Error
} }
// ExtendExpiry 延长订阅过期时间 func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"expires_at": newExpiresAt, "monthly_usage_usd": 0,
"updated_at": time.Now(), "monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error }).Error
} }
// UpdateNotes 更新订阅备注 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"notes": notes, "daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"updated_at": time.Now(), "weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error }).Error
} }
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
"status": model.SubscriptionStatusExpired, "status": service.SubscriptionStatusExpired,
"updated_at": time.Now(), "updated_at": time.Now(),
}) })
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). var subs []userSubscriptionModel
Where("user_id = ? AND group_id = ?", userID, groupID). err := r.db.WithContext(ctx).
Count(&count).Error Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
return count > 0, err Find(&subs).Error
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// CountByGroupID 获取分组的订阅数量
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()). groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
StartsAt: m.StartsAt,
ExpiresAt: m.ExpiresAt,
Status: m.Status,
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
}
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) { ...@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() { func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
sub := &model.UserSubscription{ sub := &service.UserSubscription{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
} }
...@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() { ...@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
} }
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() { func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID, AssignedBy: &admin.ID,
}) })
...@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() { ...@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdate() { func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) }))
sub.Notes = "updated notes" sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub) err := s.repo.Update(s.ctx, sub)
...@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() { ...@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
} }
func (s *UserSubscriptionRepoSuite) TestDelete() { func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() { ...@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID --- // --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() { ...@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
// Create active subscription (future expiry) // Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
...@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { ...@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
// Create expired subscription (past expiry but active status) // Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
...@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor ...@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID --- // --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() { func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() { ...@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() { func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID) subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID") s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status) s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
} }
// --- ListByGroupID --- // --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() { func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() { ...@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters --- // --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { ...@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { ...@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { ...@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired) subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
} }
// --- Usage tracking --- // --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { ...@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { ...@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
} }
func (s *UserSubscriptionRepoSuite) TestActivateWindows() { func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() { ...@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
} }
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0, DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0, WeeklyUsageUSD: 20.0,
...@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { ...@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0, WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0, MonthlyUsageUSD: 30.0,
...@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { ...@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0, MonthlyUsageUSD: 100.0,
}) })
...@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { ...@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
// --- UpdateStatus / ExtendExpiry / UpdateNotes --- // --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() { func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired) err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus") s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID) got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status) s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
} }
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { ...@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { ...@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus --- // --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() { func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() { ...@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
} }
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { ...@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
s.Require().Equal(int64(1), affected) s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID) gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status) s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID) gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status) s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
} }
// --- ExistsByUserIDAndGroupID --- // --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { ...@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID --- // --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { ...@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
}) })
...@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { ...@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID --- // --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { ...@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
...@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba ...@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(int64(1), affected, "expected 1 affected row") s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID) updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired") s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }
//go:build unit
package server_test
import (
"bytes"
"context"
"errors"
"io"
"math"
"net/http"
"net/http/httptest"
"sort"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
adminhandler "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAPIContracts(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
setup func(t *testing.T, deps *contractDeps)
method string
path string
body string
headers map[string]string
wantStatus int
wantJSON string
}{
{
name: "GET /api/v1/auth/me",
method: http.MethodGet,
path: "/api/v1/auth/me",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"wechat": "wx_alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
}`,
},
{
name: "POST /api/v1/keys",
method: http.MethodPost,
path: "/api/v1/keys",
body: `{"name":"Key One","custom_key":"sk_custom_1234567890"}`,
headers: map[string]string{
"Content-Type": "application/json",
},
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"id": 100,
"user_id": 1,
"key": "sk_custom_1234567890",
"name": "Key One",
"group_id": null,
"status": "active",
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
}`,
},
{
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
Name: "Key One",
Status: service.StatusActive,
CreatedAt: deps.now,
UpdatedAt: deps.now,
})
},
method: http.MethodGet,
path: "/api/v1/keys?page=1&page_size=10",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"items": [
{
"id": 100,
"user_id": 1,
"key": "sk_custom_1234567890",
"name": "Key One",
"group_id": null,
"status": "active",
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
],
"total": 1,
"page": 1,
"page_size": 10,
"pages": 1
}
}`,
},
{
name: "GET /api/v1/usage/stats",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
DurationMs: ptr(100),
CreatedAt: deps.now,
},
{
ID: 2,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 15,
TotalCost: 0.25,
ActualCost: 0.25,
DurationMs: ptr(300),
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/usage/stats?start_date=2025-01-01&end_date=2025-01-02",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"total_requests": 2,
"total_input_tokens": 15,
"total_output_tokens": 35,
"total_cache_tokens": 3,
"total_tokens": 53,
"total_cost": 0.75,
"total_actual_cost": 0.75,
"average_duration_ms": 200
}
}`,
},
{
name: "GET /api/v1/usage (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/usage?page=1&page_size=10",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"items": [
{
"id": 1,
"user_id": 1,
"api_key_id": 100,
"account_id": 200,
"request_id": "req_123",
"model": "claude-3",
"group_id": null,
"subscription_id": null,
"input_tokens": 10,
"output_tokens": 20,
"cache_creation_tokens": 1,
"cache_read_tokens": 2,
"cache_creation_5m_tokens": 0,
"cache_creation_1h_tokens": 0,
"input_cost": 0,
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
"created_at": "2025-01-02T03:04:05Z"
}
],
"total": 1,
"page": 1,
"page_size": 10,
"pages": 1
}
}`,
},
{
name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key",
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password": "secret",
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
deps := newContractDeps(t)
if tt.setup != nil {
tt.setup(t, deps)
}
status, body := doRequest(t, deps.router, tt.method, tt.path, tt.body, tt.headers)
require.Equal(t, tt.wantStatus, status)
require.JSONEq(t, tt.wantJSON, body)
})
}
}
type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubApiKeyRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
}
func newContractDeps(t *testing.T) *contractDeps {
t.Helper()
now := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)
userRepo := &stubUserRepo{
users: map[int64]*service.User{
1: {
ID: 1,
Email: "alice@example.com",
Username: "alice",
Wechat: "wx_alice",
Notes: "hello",
Role: service.RoleUser,
Balance: 12.5,
Concurrency: 5,
Status: service.StatusActive,
AllowedGroups: nil,
CreatedAt: now,
UpdatedAt: now,
},
},
}
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
ApiKeyPrefix: "sk-",
},
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(nil, userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 5,
})
c.Set(string(middleware.ContextKeyUserRole), service.RoleUser)
c.Next()
}
adminAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 5,
})
c.Set(string(middleware.ContextKeyUserRole), service.RoleAdmin)
c.Next()
}
r := gin.New()
v1 := r.Group("/api/v1")
v1Auth := v1.Group("")
v1Auth.Use(jwtAuth)
v1Auth.GET("/auth/me", authHandler.GetCurrentUser)
v1Keys := v1.Group("")
v1Keys.Use(jwtAuth)
v1Keys.GET("/keys", apiKeyHandler.List)
v1Keys.POST("/keys", apiKeyHandler.Create)
v1Usage := v1.Group("")
v1Usage.Use(jwtAuth)
v1Usage.GET("/usage", usageHandler.List)
v1Usage.GET("/usage/stats", usageHandler.Stats)
v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
return &contractDeps{
now: now,
router: r,
apiKeyRepo: apiKeyRepo,
usageRepo: usageRepo,
settingRepo: settingRepo,
}
}
func doRequest(t *testing.T, router http.Handler, method, path, body string, headers map[string]string) (int, string) {
t.Helper()
req := httptest.NewRequest(method, path, bytes.NewBufferString(body))
for k, v := range headers {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
respBody, err := io.ReadAll(w.Result().Body)
require.NoError(t, err)
return w.Result().StatusCode, string(respBody)
}
func ptr[T any](v T) *T { return &v }
type stubUserRepo struct {
users map[int64]*service.User
}
func (r *stubUserRepo) Create(ctx context.Context, user *service.User) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
user, ok := r.users[id]
if !ok {
return nil, service.ErrUserNotFound
}
clone := *user
return &clone, nil
}
func (r *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
for _, user := range r.users {
if user.Email == email {
clone := *user
return &clone, nil
}
}
return nil, service.ErrUserNotFound
}
func (r *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
for _, user := range r.users {
if user.Role == service.RoleAdmin && user.Status == service.StatusActive {
clone := *user
return &clone, nil
}
}
return nil, service.ErrUserNotFound
}
func (r *stubUserRepo) Update(ctx context.Context, user *service.User) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}
func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}
func (stubGroupRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, errors.New("not implemented")
}
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct{}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
type stubApiKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
}
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
}
}
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
if key == nil {
return
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
if key.ID == 0 {
key.ID = r.nextID
r.nextID++
}
if key.CreatedAt.IsZero() {
key.CreatedAt = r.now
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
return nil
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrApiKeyNotFound
}
clone := *key
return &clone, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrApiKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
return nil
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrApiKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
ids = append(ids, id)
}
}
sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] })
start := params.Offset()
if start > len(ids) {
start = len(ids)
}
end := start + params.Limit()
if end > len(ids) {
end = len(ids)
}
out := make([]service.ApiKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
}
total := int64(len(ids))
pageSize := params.Limit()
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
return out, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: pageSize,
Pages: pages,
}, nil
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
for _, key := range r.byID {
if key.UserID == userID {
count++
}
}
return count, nil
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key]
return ok, nil
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
func newStubUsageLogRepo() *stubUsageLogRepo {
return &stubUsageLogRepo{userLogs: make(map[int64][]service.UsageLog)}
}
func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
r.userLogs[userID] = logs
}
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error {
return errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
logs := r.userLogs[userID]
total := int64(len(logs))
out := paginateLogs(logs, params)
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
logs := r.userLogs[userID]
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string
}
func newStubSettingRepo() *stubSettingRepo {
return &stubSettingRepo{all: make(map[string]string)}
}
func (r *stubSettingRepo) SetAll(values map[string]string) {
r.all = make(map[string]string, len(values))
for k, v := range values {
r.all[k] = v
}
}
func (r *stubSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
value, ok := r.all[key]
if !ok {
return nil, service.ErrSettingNotFound
}
return &service.Setting{Key: key, Value: value}, nil
}
func (r *stubSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
value, ok := r.all[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (r *stubSettingRepo) Set(ctx context.Context, key, value string) error {
r.all[key] = value
return nil
}
func (r *stubSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = r.all[key]
}
return out, nil
}
func (r *stubSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
r.all[k] = v
}
return nil
}
func (r *stubSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.all))
for k, v := range r.all {
out[k] = v
}
return out, nil
}
func (r *stubSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.all, key)
return nil
}
func paginateLogs(logs []service.UsageLog, params pagination.PaginationParams) []service.UsageLog {
start := params.Offset()
if start > len(logs) {
start = len(logs)
}
end := start + params.Limit()
if end > len(logs) {
end = len(logs)
}
out := make([]service.UsageLog, 0, end-start)
out = append(out, logs[start:end]...)
return out
}
func paginationResult(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pageSize := params.Limit()
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: pageSize,
Pages: pages,
}
}
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
_ service.SettingRepository = (*stubSettingRepo)(nil)
)
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -84,7 +83,11 @@ func validateAdminApiKey( ...@@ -84,7 +83,11 @@ func validateAdminApiKey(
return false return false
} }
c.Set(string(ContextKeyUser), admin) c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key") c.Set("auth_method", "admin_api_key")
return true return true
} }
...@@ -121,12 +124,16 @@ func validateJWTForAdmin( ...@@ -121,12 +124,16 @@ func validateJWTForAdmin(
} }
// 检查管理员权限 // 检查管理员权限
if user.Role != model.RoleAdmin { if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false return false
} }
c.Set(string(ContextKeyUser), user) c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt") c.Set("auth_method", "jwt")
return true return true
......
package middleware package middleware
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -10,15 +10,14 @@ import ( ...@@ -10,15 +10,14 @@ import (
// 必须在JWTAuth中间件之后使用 // 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc { func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 从上下文获取用户 role, ok := GetUserRoleFromContext(c)
user, exists := GetUserFromContext(c) if !ok {
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context") AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return return
} }
// 检查是否为管理员 // 检查是否为管理员
if user.Role != model.RoleAdmin { if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return return
} }
......
...@@ -5,11 +5,9 @@ import ( ...@@ -5,11 +5,9 @@ import (
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
...@@ -61,7 +59,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -61,7 +59,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
...@@ -136,28 +134,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -136,28 +134,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User) c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next() c.Next()
} }
} }
// GetApiKeyFromContext 从上下文中获取API key // GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) { func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey)) value, exists := c.Get(string(ContextKeyApiKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*model.ApiKey) apiKey, ok := value.(*service.ApiKey)
return apiKey, ok return apiKey, ok
} }
// GetSubscriptionFromContext 从上下文中获取订阅信息 // GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) { func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription)) value, exists := c.Get(string(ContextKeySubscription))
if !exists { if !exists {
return nil, false return nil, false
} }
subscription, ok := value.(*model.UserSubscription) subscription, ok := value.(*service.UserSubscription)
return subscription, ok return subscription, ok
} }
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) ...@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
return return
} }
// 将用户信息存入上下文 c.Set(string(ContextKeyUser), AuthSubject{
c.Set(string(ContextKeyUser), user) UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next() c.Next()
} }
} }
// GetUserFromContext 从上下文中获取用户 // Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}
...@@ -8,6 +8,8 @@ type ContextKey string ...@@ -8,6 +8,8 @@ type ContextKey string
const ( const (
// ContextKeyUser 用户上下文键 // ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
......
package model package service
import ( import "time"
"database/sql/driver"
"encoding/json"
"errors"
"strconv"
"time"
"gorm.io/gorm" type Account struct {
) ID int64
Name string
// JSONB 用于存储JSONB数据 Platform string
type JSONB map[string]any Type string
Credentials map[string]any
func (j JSONB) Value() (driver.Value, error) { Extra map[string]any
if j == nil { ProxyID *int64
return nil, nil Concurrency int
} Priority int
return json.Marshal(j) Status string
} ErrorMessage string
LastUsedAt *time.Time
func (j *JSONB) Scan(value any) error { CreatedAt time.Time
if value == nil { UpdatedAt time.Time
*j = nil
return nil Schedulable bool
}
bytes, ok := value.([]byte) RateLimitedAt *time.Time
if !ok { RateLimitResetAt *time.Time
return errors.New("type assertion to []byte failed") OverloadUntil *time.Time
}
return json.Unmarshal(bytes, j) SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
Proxy *Proxy
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
} }
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool { func (a *Account) IsActive() bool {
return a.Status == "active" return a.Status == StatusActive
} }
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool { func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable { if !a.IsActive() || !a.Schedulable {
return false return false
...@@ -97,7 +52,6 @@ func (a *Account) IsSchedulable() bool { ...@@ -97,7 +52,6 @@ func (a *Account) IsSchedulable() bool {
return true return true
} }
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool { func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil { if a.RateLimitResetAt == nil {
return false return false
...@@ -105,7 +59,6 @@ func (a *Account) IsRateLimited() bool { ...@@ -105,7 +59,6 @@ func (a *Account) IsRateLimited() bool {
return time.Now().Before(*a.RateLimitResetAt) return time.Now().Before(*a.RateLimitResetAt)
} }
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool { func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil { if a.OverloadUntil == nil {
return false return false
...@@ -113,60 +66,26 @@ func (a *Account) IsOverloaded() bool { ...@@ -113,60 +66,26 @@ func (a *Account) IsOverloaded() bool {
return time.Now().Before(*a.OverloadUntil) return time.Now().Before(*a.OverloadUntil)
} }
// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token)
func (a *Account) IsOAuth() bool { func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
} }
// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限)
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string { func (a *Account) GetCredential(key string) string {
if a.Credentials == nil { if a.Credentials == nil {
return "" return ""
} }
if v, ok := a.Credentials[key]; ok { if v, ok := a.Credentials[key]; ok {
switch vv := v.(type) { if s, ok := v.(string); ok {
case string: return s
return vv
case json.Number:
return vv.String()
case float64:
// JSON numbers decode to float64; keep integer formatting for integer-like values.
i := int64(vv)
if vv == float64(i) {
return strconv.FormatInt(i, 10)
}
return strconv.FormatFloat(vv, 'f', -1, 64)
case float32:
f := float64(vv)
i := int64(f)
if f == float64(i) {
return strconv.FormatInt(i, 10)
}
return strconv.FormatFloat(f, 'f', -1, 64)
case int:
return strconv.FormatInt(int64(vv), 10)
case int64:
return strconv.FormatInt(vv, 10)
case int32:
return strconv.FormatInt(int64(vv), 10)
case uint:
return strconv.FormatUint(uint64(vv), 10)
case uint64:
return strconv.FormatUint(vv, 10)
case uint32:
return strconv.FormatUint(uint64(vv), 10)
} }
} }
return "" return ""
} }
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -175,7 +94,6 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -175,7 +94,6 @@ func (a *Account) GetModelMapping() map[string]string {
if !ok || raw == nil { if !ok || raw == nil {
return nil return nil
} }
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]any); ok { if m, ok := raw.(map[string]any); ok {
result := make(map[string]string) result := make(map[string]string)
for k, v := range m { for k, v := range m {
...@@ -190,19 +108,15 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -190,19 +108,15 @@ func (a *Account) GetModelMapping() map[string]string {
return nil return nil
} }
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型 return true
} }
_, exists := mapping[requestedModel] _, exists := mapping[requestedModel]
return exists return exists
} }
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
...@@ -214,19 +128,17 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -214,19 +128,17 @@ func (a *Account) GetMappedModel(requestedModel string) string {
return requestedModel return requestedModel
} }
// GetBaseURL 获取API基础URL(用于apikey类型账号)
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey { if a.Type != AccountTypeApiKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL == "" { if baseURL == "" {
return "https://api.anthropic.com" // 默认URL return "https://api.anthropic.com"
} }
return baseURL return baseURL
} }
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string { func (a *Account) GetExtraString(key string) string {
if a.Extra == nil { if a.Extra == nil {
return "" return ""
...@@ -239,7 +151,6 @@ func (a *Account) GetExtraString(key string) string { ...@@ -239,7 +151,6 @@ func (a *Account) GetExtraString(key string) string {
return "" return ""
} }
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil { if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false return false
...@@ -252,7 +163,6 @@ func (a *Account) IsCustomErrorCodesEnabled() bool { ...@@ -252,7 +163,6 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
return false return false
} }
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int { func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -261,11 +171,9 @@ func (a *Account) GetCustomErrorCodes() []int { ...@@ -261,11 +171,9 @@ func (a *Account) GetCustomErrorCodes() []int {
if !ok || raw == nil { if !ok || raw == nil {
return nil return nil
} }
// 处理 []interface{} 类型(JSON反序列化后的格式)
if arr, ok := raw.([]any); ok { if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr)) result := make([]int, 0, len(arr))
for _, v := range arr { for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok { if f, ok := v.(float64); ok {
result = append(result, int(f)) result = append(result, int(f))
} }
...@@ -275,18 +183,14 @@ func (a *Account) GetCustomErrorCodes() []int { ...@@ -275,18 +183,14 @@ func (a *Account) GetCustomErrorCodes() []int {
return nil return nil
} }
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略)
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool { func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() { if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略 return true
} }
codes := a.GetCustomErrorCodes() codes := a.GetCustomErrorCodes()
if len(codes) == 0 { if len(codes) == 0 {
return true // 启用但列表为空,fallback到默认策略 return true
} }
// 检查是否在自定义列表中
for _, code := range codes { for _, code := range codes {
if code == statusCode { if code == statusCode {
return true return true
...@@ -295,8 +199,6 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool { ...@@ -295,8 +199,6 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
return false return false
} }
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool { func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil { if a.Credentials == nil {
return false return false
...@@ -309,36 +211,22 @@ func (a *Account) IsInterceptWarmupEnabled() bool { ...@@ -309,36 +211,22 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false return false
} }
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool { func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI return a.Platform == PlatformOpenAI
} }
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool { func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic return a.Platform == PlatformAnthropic
} }
// IsGemini 检查是否为 Gemini 平台账号
func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool { func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth return a.IsOpenAI() && a.Type == AccountTypeOAuth
} }
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
func (a *Account) IsOpenAIApiKey() bool { func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey return a.IsOpenAI() && a.Type == AccountTypeApiKey
} }
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -349,10 +237,9 @@ func (a *Account) GetOpenAIBaseURL() string { ...@@ -349,10 +237,9 @@ func (a *Account) GetOpenAIBaseURL() string {
return baseURL return baseURL
} }
} }
return "https://api.openai.com" // OpenAI 默认 API URL return "https://api.openai.com"
} }
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string { func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -360,7 +247,6 @@ func (a *Account) GetOpenAIAccessToken() string { ...@@ -360,7 +247,6 @@ func (a *Account) GetOpenAIAccessToken() string {
return a.GetCredential("access_token") return a.GetCredential("access_token")
} }
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string { func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -368,7 +254,6 @@ func (a *Account) GetOpenAIRefreshToken() string { ...@@ -368,7 +254,6 @@ func (a *Account) GetOpenAIRefreshToken() string {
return a.GetCredential("refresh_token") return a.GetCredential("refresh_token")
} }
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
func (a *Account) GetOpenAIIDToken() string { func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -376,7 +261,6 @@ func (a *Account) GetOpenAIIDToken() string { ...@@ -376,7 +261,6 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token") return a.GetCredential("id_token")
} }
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string { func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() { if !a.IsOpenAIApiKey() {
return "" return ""
...@@ -384,8 +268,6 @@ func (a *Account) GetOpenAIApiKey() string { ...@@ -384,8 +268,6 @@ func (a *Account) GetOpenAIApiKey() string {
return a.GetCredential("api_key") return a.GetCredential("api_key")
} }
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string { func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -393,7 +275,6 @@ func (a *Account) GetOpenAIUserAgent() string { ...@@ -393,7 +275,6 @@ func (a *Account) GetOpenAIUserAgent() string {
return a.GetCredential("user_agent") return a.GetCredential("user_agent")
} }
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string { func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -401,7 +282,6 @@ func (a *Account) GetChatGPTAccountID() string { ...@@ -401,7 +282,6 @@ func (a *Account) GetChatGPTAccountID() string {
return a.GetCredential("chatgpt_account_id") return a.GetCredential("chatgpt_account_id")
} }
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string { func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -409,7 +289,6 @@ func (a *Account) GetChatGPTUserID() string { ...@@ -409,7 +289,6 @@ func (a *Account) GetChatGPTUserID() string {
return a.GetCredential("chatgpt_user_id") return a.GetCredential("chatgpt_user_id")
} }
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string { func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -417,7 +296,6 @@ func (a *Account) GetOpenAIOrganizationID() string { ...@@ -417,7 +296,6 @@ func (a *Account) GetOpenAIOrganizationID() string {
return a.GetCredential("organization_id") return a.GetCredential("organization_id")
} }
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time { func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return nil return nil
...@@ -426,25 +304,21 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time { ...@@ -426,25 +304,21 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if expiresAtStr == "" { if expiresAtStr == "" {
return nil return nil
} }
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr) t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil { if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok { if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0) tt := time.Unix(int64(v), 0)
return &t return &tt
} }
return nil return nil
} }
return &t return &t
} }
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool { func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt() expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil { if expiresAt == nil {
return false // 没有过期时间信息,假设未过期 return false
} }
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt) return time.Now().Add(60 * time.Second).After(*expiresAt)
} }
package service
import "time"
type AccountGroup struct {
AccountID int64
GroupID int64
Priority int
CreatedAt time.Time
Account *Account
Group *Group
}
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
...@@ -15,29 +14,29 @@ var ( ...@@ -15,29 +14,29 @@ var (
) )
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found. // Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
Update(ctx context.Context, account *model.Account) error Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]model.Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error) ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
...@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) ...@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
} }
// Create 创建账号 // Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) { func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组) // 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 { if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
...@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// 创建账号 // 创建账号
account := &model.Account{ account := &Account{
Name: req.Name, Name: req.Name,
Platform: req.Platform, Platform: req.Platform,
Type: req.Type, Type: req.Type,
...@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
...@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// GetByID 根据ID获取账号 // GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
...@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, ...@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
} }
// List 获取账号列表 // List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params) accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err) return nil, nil, fmt.Errorf("list accounts: %w", err)
...@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP ...@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
} }
// ListByPlatform 根据平台获取账号列表 // ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform) accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err) return nil, fmt.Errorf("list accounts by platform: %w", err)
...@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([ ...@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
} }
// ListByGroup 根据分组获取账号列表 // ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID) accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err) return nil, fmt.Errorf("list accounts by group: %w", err)
...@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode ...@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
} }
// Update 更新账号 // Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
...@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { ...@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
// 根据平台执行不同的测试逻辑 // 根据平台执行不同的测试逻辑
switch account.Platform { switch account.Platform {
case model.PlatformAnthropic: case PlatformAnthropic:
// TODO: 测试Anthropic API凭证 // TODO: 测试Anthropic API凭证
return nil return nil
case model.PlatformOpenAI: case PlatformOpenAI:
// TODO: 测试OpenAI API凭证 // TODO: 测试OpenAI API凭证
return nil return nil
case model.PlatformGemini: case PlatformGemini:
// TODO: 测试Gemini API凭证 // TODO: 测试Gemini API凭证
return nil return nil
default: default:
......
...@@ -11,11 +11,11 @@ import ( ...@@ -11,11 +11,11 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...@@ -23,6 +23,10 @@ import ( ...@@ -23,6 +23,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testOpenAIAPIURL = "https://api.openai.com/v1/responses" testOpenAIAPIURL = "https://api.openai.com/v1/responses"
...@@ -141,7 +145,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -141,7 +145,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
} }
// testClaudeAccountConnection tests an Anthropic Claude account's connection // testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Determine the model to use // Determine the model to use
...@@ -268,7 +272,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -268,7 +272,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
} }
// testOpenAIAccountConnection tests an OpenAI account's connection // testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing // Default to openai.DefaultTestModel for OpenAI testing
...@@ -667,11 +671,11 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) ...@@ -667,11 +671,11 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") { if line == "" || !sseDataPrefix.MatchString(line) {
continue continue
} }
jsonStr := strings.TrimPrefix(line, "data: ") jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" { if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
...@@ -721,11 +725,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) ...@@ -721,11 +725,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") { if line == "" || !sseDataPrefix.MatchString(line) {
continue continue
} }
jsonStr := strings.TrimPrefix(line, "data: ") jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" { if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment