Commit e5a77853 authored by Forest's avatar Forest
Browse files

refactor: 调整项目结构为单向依赖

parent b3463769
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository { ...@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var proxy model.Proxy var m proxyModel
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil) return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return proxyModelToService(&m), nil
} }
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
} }
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []proxyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{}) db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters // Apply filters
if protocol != "" { if protocol != "" {
...@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outProxies := make([]service.Proxy, 0, len(proxies))
if int(total)%params.Limit() > 0 { for i := range proxies {
pages++ outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
} }
return proxies, &pagination.PaginationResult{ return outProxies, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
return proxies, err if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return outProxies, nil
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error Count(&count).Error
if err != nil { if err != nil {
...@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, ...@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
Count(&count).Error Count(&count).Error
return count, err return count, err
...@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i ...@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
var results []result var results []result
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Model(&model.Account{}). Table("accounts").
Select("proxy_id, COUNT(*) as count"). Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL"). Where("proxy_id IS NOT NULL").
Group("proxy_id"). Group("proxy_id").
...@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i ...@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Order("created_at DESC"). Order("created_at DESC").
Find(&proxies).Error Find(&proxies).Error
if err != nil { if err != nil {
...@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod ...@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
} }
// Build result with account counts // Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i, proxy := range proxies { for i := range proxies {
result[i] = model.ProxyWithAccountCount{ proxy := proxyModelToService(&proxies[i])
Proxy: proxy, if proxy == nil {
AccountCount: counts[proxy.ID], continue
} }
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
})
} }
return result, nil return result, nil
} }
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
}
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) { ...@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() { func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{ proxy := &service.Proxy{
Name: "test-create", Name: "test-create",
Protocol: "http", Protocol: "http",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 8080, Port: 8080,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, proxy) err := s.repo.Create(s.ctx, proxy)
...@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() { ...@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ProxyRepoSuite) TestUpdate() { func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"}) proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy.Name = "updated" proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy) err := s.repo.Update(s.ctx, proxy)
...@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() { ...@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
} }
func (s *ProxyRepoSuite) TestDelete() { func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID) err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() { ...@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() { func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() { ...@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { ...@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Status() { func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(proxies, 1) s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status) s.Require().Equal(service.StatusDisabled, proxies[0].Status)
} }
func (s *ProxyRepoSuite) TestListWithFilters_Search() { func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() { ...@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive --- // --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() { func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx) proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() { ...@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth --- // --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
...@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { ...@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
} }
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p-noauth", Name: "p-noauth",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
...@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { ...@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID --- // --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")
...@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { ...@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
} }
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { ...@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies --- // --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() { func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx) counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies") s.Require().NoError(err, "GetAccountCountsForProxies")
...@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() { ...@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour), CreatedAt: base.Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base, CreatedAt: base,
}) })
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive", Name: "p3-inactive",
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount") s.Require().NoError(err, "ListActiveWithAccountCount")
...@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { ...@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test --- // --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
...@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { ...@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
CreatedAt: time.Now().Add(-1 * time.Hour), CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour), UpdatedAt: time.Now().Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
...@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { ...@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
s.Require().NoError(err, "ExistsByHostPortAuth") s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist") s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID) count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")
......
...@@ -4,10 +4,8 @@ import ( ...@@ -4,10 +4,8 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository { ...@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
}
return r.db.WithContext(ctx).Create(&models).Error
} }
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var code model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var redeemCode model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
} }
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { var codes []redeemCodeModel
var codes []model.RedeemCode
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{}) db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
// Apply filters
if codeType != "" { if codeType != "" {
db = db.Where("type = ?", codeType) db = db.Where("type = ?", codeType)
} }
...@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outCodes := make([]service.RedeemCode, 0, len(codes))
if int(total)%params.Limit() > 0 { for i := range codes {
pages++ outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
} }
return codes, &pagination.PaginationResult{ return outCodes, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusUsed, "status": service.StatusUsed,
"used_by": userID, "used_by": userID,
"used_at": now, "used_at": now,
}) })
...@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error ...@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10
} }
var codes []redeemCodeModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("used_by = ?", userID). Where("used_by = ?", userID).
Order("used_at DESC"). Order("used_at DESC").
Limit(limit). Limit(limit).
Find(&codes).Error Find(&codes).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return codes, nil
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
} }
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) { ...@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
// --- Create / CreateBatch / GetByID / GetByCode --- // --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() { func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{ code := &service.RedeemCode{
Code: "TEST-CREATE", Code: "TEST-CREATE",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Value: 100, Value: 100,
Status: model.StatusUnused, Status: service.StatusUnused,
} }
err := s.repo.Create(s.ctx, code) err := s.repo.Create(s.ctx, code)
...@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() { ...@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
} }
func (s *RedeemCodeRepoSuite) TestCreateBatch() { func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused}, {Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused}, {Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
} }
err := s.repo.CreateBatch(s.ctx, codes) err := s.repo.CreateBatch(s.ctx, codes)
...@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() { ...@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
} }
func (s *RedeemCodeRepoSuite) TestGetByCode() { func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE") got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
...@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() { ...@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
// --- Delete --- // --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() { func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID) err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() { ...@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() { func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() { ...@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type) s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status) s.Require().Equal(service.StatusUsed, codes[0].Status)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { ...@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP", Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
GroupID: &group.ID, GroupID: &group.ID,
}) })
...@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { ...@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update --- // --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() { func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10}) code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code.Value = 50 code.Value = 50
err := s.repo.Update(s.ctx, code) err := s.repo.Update(s.ctx, code)
...@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() { ...@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use --- // --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() { func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use") s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID) got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status) s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy) s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy) s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt) s.Require().NotNil(got.UsedAt)
} }
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time") s.Require().NoError(err, "Use first time")
...@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { ...@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
...@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { ...@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser --- // --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() { func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering // Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1", Code: "USER-1",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c1).Update("used_at", base) s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2", Code: "USER-2",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour)) s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
...@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() { ...@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP", Code: "WITH-GRP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
GroupID: &group.ID, GroupID: &group.ID,
}) })
...@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { ...@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM", Code: "DEF-LIM",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c).Update("used_at", time.Now()) s.db.Model(c).Update("used_at", time.Now())
...@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { ...@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test --- // --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() { func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()}, {Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, {Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
} }
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch") s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code") list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1) s.Require().Len(list, 1)
...@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser ...@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering // Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA") s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10) used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser") s.Require().NoError(err, "ListByUser")
......
...@@ -6,33 +6,27 @@ import ( ...@@ -6,33 +6,27 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层
type settingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) service.SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { var m settingModel
var setting model.Setting err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil) return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return settingModelToService(&m), nil
} }
// GetValue 获取设置值字符串
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
...@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e ...@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
return setting.Value, nil return setting.Value, nil
} }
// Set 设置值(存在则更新,不存在则创建)
func (r *settingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
...@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error { ...@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{ return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error }).Create(m).Error
} }
// GetMultiple 批量获取设置
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map ...@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
return result, nil return result, nil
} }
// SetMultiple 批量设置值
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
...@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string ...@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
if err := tx.Clauses(clause.OnConflict{ if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil { }).Create(m).Error; err != nil {
return err return err
} }
} }
...@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string ...@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
}) })
} }
// GetAll 获取所有设置
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro ...@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
return result, nil return result, nil
} }
// Delete 删除设置
func (r *settingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
} }
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
...@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
TokenCount int64 `gorm:"column:token_count"` TokenCount int64 `gorm:"column:token_count"`
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as request_count, COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
...@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error m := usageLogModelFromService(log)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUsageLogModelToService(log, m)
}
return err
} }
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
var log model.UsageLog var log usageLogModel
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil) return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return usageLogModelToService(&log), nil
} }
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p ...@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UserStats 用户使用统计 // UserStats 用户使用统计
...@@ -125,7 +109,7 @@ type UserStats struct { ...@@ -125,7 +109,7 @@ type UserStats struct {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
...@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
today := timezone.Today() today := timezone.Today()
// 总用户数 // 总用户数
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers) r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
// 今日新增用户数 // 今日新增用户数
r.db.WithContext(ctx).Model(&model.User{}). r.db.WithContext(ctx).Model(&userModel{}).
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.TodayNewUsers) Count(&stats.TodayNewUsers)
// 今日活跃用户数 (今日有请求的用户) // 今日活跃用户数 (今日有请求的用户)
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Distinct("user_id"). Distinct("user_id").
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.ActiveUsers) Count(&stats.ActiveUsers)
// 总 API Key 数 // 总 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys) r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
// 活跃 API Key 数 // 活跃 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 总账户数 // 总账户数
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts) r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
// 正常账户数 (schedulable=true, status=active) // 正常账户数 (schedulable=true, status=active)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Count(&stats.NormalAccounts) Count(&stats.NormalAccounts)
// 异常账户数 (status=error) // 异常账户数 (status=error)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ?", model.StatusError). Where("status = ?", service.StatusError).
Count(&stats.ErrorAccounts) Count(&stats.ErrorAccounts)
// 限流账户数 // 限流账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()). Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
Count(&stats.RateLimitAccounts) Count(&stats.RateLimitAccounts)
// 过载账户数 // 过载账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()). Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
Count(&stats.OverloadAccounts) Count(&stats.OverloadAccounts)
...@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
...@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
...@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
...@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
...@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today() today := timezone.Today()
// API Key 统计 // API Key 统计
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Count(&stats.TotalApiKeys) Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ? AND status = ?", userID, model.StatusActive). Where("user_id = ? AND status = ?", userID, service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 累计 Token 统计 // 累计 Token 统计
...@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
...@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user ...@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user ...@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 ...@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}) db := r.db.WithContext(ctx).Model(&usageLogModel{})
// Apply filters // Apply filters
if filters.UserID > 0 { if filters.UserID > 0 {
...@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UsageStats represents usage statistics // UsageStats represents usage statistics
...@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("user_id IN ?", userIDs). Where("user_id IN ?", userIDs).
Group("user_id"). Group("user_id").
...@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("user_id IN ? AND created_at >= ?", userIDs, today). Where("user_id IN ? AND created_at >= ?", userIDs, today).
Group("user_id"). Group("user_id").
...@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("api_key_id IN ?", apiKeyIDs). Where("api_key_id IN ?", apiKeyIDs).
Group("api_key_id"). Group("api_key_id").
...@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today). Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
Group("api_key_id"). Group("api_key_id").
...@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT ...@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
...@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost float64 `gorm:"column:actual_cost"` ActualCost float64 `gorm:"column:actual_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date, TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var avgDuration struct { var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"` AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms"). Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration) Scan(&avgDuration)
...@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Models: models, Models: models,
}, nil }, nil
} }
type usageLogModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
ApiKeyID int64 `gorm:"index;not null"`
AccountID int64 `gorm:"index;not null"`
RequestID string `gorm:"size:64"`
Model string `gorm:"size:100;index;not null"`
GroupID *int64 `gorm:"index"`
SubscriptionID *int64 `gorm:"index"`
InputTokens int `gorm:"default:0;not null"`
OutputTokens int `gorm:"default:0;not null"`
CacheCreationTokens int `gorm:"default:0;not null"`
CacheReadTokens int `gorm:"default:0;not null"`
CacheCreation5mTokens int `gorm:"default:0;not null"`
CacheCreation1hTokens int `gorm:"default:0;not null"`
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
BillingType int8 `gorm:"type:smallint;default:0;not null"`
Stream bool `gorm:"default:false;not null"`
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time `gorm:"index;not null"`
User *userModel `gorm:"foreignKey:UserID"`
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
}
func (usageLogModel) TableName() string { return "usage_logs" }
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
if m == nil {
return nil
}
return &service.UsageLog{
ID: m.ID,
UserID: m.UserID,
ApiKeyID: m.ApiKeyID,
AccountID: m.AccountID,
RequestID: m.RequestID,
Model: m.Model,
GroupID: m.GroupID,
SubscriptionID: m.SubscriptionID,
InputTokens: m.InputTokens,
OutputTokens: m.OutputTokens,
CacheCreationTokens: m.CacheCreationTokens,
CacheReadTokens: m.CacheReadTokens,
CacheCreation5mTokens: m.CacheCreation5mTokens,
CacheCreation1hTokens: m.CacheCreation1hTokens,
InputCost: m.InputCost,
OutputCost: m.OutputCost,
CacheCreationCost: m.CacheCreationCost,
CacheReadCost: m.CacheReadCost,
TotalCost: m.TotalCost,
ActualCost: m.ActualCost,
RateMultiplier: m.RateMultiplier,
BillingType: m.BillingType,
Stream: m.Stream,
DurationMs: m.DurationMs,
FirstTokenMs: m.FirstTokenMs,
CreatedAt: m.CreatedAt,
User: userModelToService(m.User),
ApiKey: apiKeyModelToService(m.ApiKey),
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
Subscription: userSubscriptionModelToService(m.Subscription),
}
}
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
out := make([]service.UsageLog, 0, len(models))
for i := range models {
if s := usageLogModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
if log == nil {
return nil
}
return &usageLogModel{
ID: log.ID,
UserID: log.UserID,
ApiKeyID: log.ApiKeyID,
AccountID: log.AccountID,
RequestID: log.RequestID,
Model: log.Model,
GroupID: log.GroupID,
SubscriptionID: log.SubscriptionID,
InputTokens: log.InputTokens,
OutputTokens: log.OutputTokens,
CacheCreationTokens: log.CacheCreationTokens,
CacheReadTokens: log.CacheReadTokens,
CacheCreation5mTokens: log.CacheCreation5mTokens,
CacheCreation1hTokens: log.CacheCreation1hTokens,
InputCost: log.InputCost,
OutputCost: log.OutputCost,
CacheCreationCost: log.CacheCreationCost,
CacheReadCost: log.CacheReadCost,
TotalCost: log.TotalCost,
ActualCost: log.ActualCost,
RateMultiplier: log.RateMultiplier,
BillingType: log.BillingType,
Stream: log.Stream,
DurationMs: log.DurationMs,
FirstTokenMs: log.FirstTokenMs,
CreatedAt: log.CreatedAt,
}
}
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
if log == nil || m == nil {
return
}
log.ID = m.ID
log.CreatedAt = m.CreatedAt
}
...@@ -7,10 +7,10 @@ import ( ...@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) { ...@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe ...@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
// --- Create / GetByID --- // --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() { ...@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
} }
func (s *UsageLogRepoSuite) TestGetByID() { func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { ...@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() { ...@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser --- // --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() { func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
...@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() { ...@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey --- // --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() { func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
...@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() { ...@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount --- // --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() { ...@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats --- // --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() { ...@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters --- // --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now() now := time.Now()
todayStart := timezone.Today() todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{ userToday := mustCreateUser(s.T(), s.db, &userModel{
Email: "today@example.com", Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)), CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now, UpdatedAt: now,
}) })
userOld := mustCreateUser(s.T(), s.db, &model.User{ userOld := mustCreateUser(s.T(), s.db, &userModel{
Email: "old@example.com", Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour), CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour), UpdatedAt: todayStart.Add(-24 * time.Hour),
}) })
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true}) accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300 d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{ logToday := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{ logOld := &service.UsageLog{
UserID: userOld.ID, UserID: userOld.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{ logPerf := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
...@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats --- // --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { ...@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats --- // --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { ...@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats --- // --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
...@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { ...@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats --- // --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
...@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { ...@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats --- // --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() { func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time { ...@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange --- // --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { ...@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange --- // --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { ...@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange --- // --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { ...@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange --- // --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{ log3 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats --- // --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
now := time.Now() now := time.Now()
windowStart := now.Add(-10 * time.Minute) windowStart := now.Add(-10 * time.Minute)
...@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { ...@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID --- // --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { ...@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
} }
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { ...@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats --- // --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters --- // --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
} }
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters --- // --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats --- // --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days // Create logs on different days
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base startTime := base
...@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { ...@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend --- // --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
...@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { ...@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend --- // --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
...@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { ...@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
} }
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
...@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { ...@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) --- // --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { ...@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
...@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { ...@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
......
...@@ -2,12 +2,13 @@ package repository ...@@ -2,12 +2,13 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository { ...@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db} return &userRepository{db: db}
} }
func (r *userRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Create(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Save(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
} }
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []model.User var users []userModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.User{}) db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters // Apply filters
if status != "" { if status != "" {
...@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Batch load subscriptions for all users (avoid N+1) // Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 { if len(users) > 0 {
userIDs := make([]int64, len(users)) userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users)) userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users { for i := range users {
userIDs[i] = users[i].ID userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i] u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
} }
// Query active subscriptions with groups in one query // Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive). Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil { Find(&subscriptions).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Associate subscriptions with users // Associate subscriptions with users
for i := range subscriptions { for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok { if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i]) user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
} }
} }
return outUsers, paginationResultFromTotal(total, params), nil
} }
pages := int(total) / params.Limit() outUsers := make([]service.User, 0, len(users))
if int(total)%params.Limit() > 0 { for i := range users {
pages++ outUsers = append(outUsers, *userModelToService(&users[i]))
} }
return users, &pagination.PaginationResult{ return outUsers, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil { if result.Error != nil {
...@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo ...@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
} }
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
} }
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
} }
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
...@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) { ...@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByEmail / Update / Delete --- // --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() { func (s *UserRepoSuite) TestCreate() {
user := &model.User{ user := &service.User{
Email: "create@test.com", Email: "create@test.com",
Username: "testuser", Username: "testuser",
Role: model.RoleUser, PasswordHash: "test-password-hash",
Status: model.StatusActive, Role: service.RoleUser,
Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, user) err := s.repo.Create(s.ctx, user)
...@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() { ...@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserRepoSuite) TestGetByEmail() { func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email) got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail") s.Require().NoError(err, "GetByEmail")
...@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() { ...@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
} }
func (s *UserRepoSuite) TestUpdate() { func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"}) user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated" user.Username = "updated"
err := s.repo.Update(s.ctx, user) err := s.repo.Update(s.ctx, user)
...@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() { ...@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
} }
func (s *UserRepoSuite) TestDelete() { func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID) err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() { ...@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() { func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() { ...@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
} }
func (s *UserRepoSuite) TestListWithFilters_Status() { func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive}) mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled}) mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status) s.Require().Equal(service.StatusActive, users[0].Status)
} }
func (s *UserRepoSuite) TestListWithFilters_Role() { func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser}) mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role) s.Require().Equal(service.RoleAdmin, users[0].Role)
} }
func (s *UserRepoSuite) TestListWithFilters_Search() { func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"}) mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"}) mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() { ...@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { ...@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { ...@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
} }
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
}) })
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour), ExpiresAt: time.Now().Add(-1 * time.Hour),
}) })
...@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { ...@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
} }
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
target := mustCreateUser(s.T(), s.db, &model.User{ target := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
...@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations --- // --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() { func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5) err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance") s.Require().NoError(err, "UpdateBalance")
...@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() { ...@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
} }
func (s *UserRepoSuite) TestUpdateBalance_Negative() { func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3) err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative") s.Require().NoError(err, "UpdateBalance with negative")
...@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() { ...@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
} }
func (s *UserRepoSuite) TestDeductBalance() { func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5) err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance") s.Require().NoError(err, "DeductBalance")
...@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() { ...@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
} }
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
...@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { ...@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10) err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount") s.Require().NoError(err, "DeductBalance exact amount")
...@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { ...@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
// --- Concurrency --- // --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() { func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3) err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency") s.Require().NoError(err, "UpdateConcurrency")
...@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() { ...@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
} }
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2) err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative") s.Require().NoError(err, "UpdateConcurrency negative")
...@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { ...@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail --- // --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() { func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com") exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail") s.Require().NoError(err, "ExistsByEmail")
...@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() { ...@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42) groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{ userA := mustCreateUser(s.T(), s.db, &userModel{
Email: "a1@example.com", Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7}, AllowedGroups: pq.Int64Array{groupID, 7},
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a2@example.com", Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7}, AllowedGroups: pq.Int64Array{7},
}) })
...@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { ...@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
} }
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "nomatch@test.com", Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3}, AllowedGroups: pq.Int64Array{1, 2, 3},
}) })
...@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { ...@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin --- // --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{ admin1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "admin1@example.com", Email: "admin1@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "admin2@example.com", Email: "admin2@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() { ...@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "user@example.com", Email: "user@example.com",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
}) })
_, err := s.repo.GetFirstAdmin(s.ctx) _, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { ...@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "disabled@example.com", Email: "disabled@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{ activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
Email: "active@example.com", Email: "active@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
...@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { ...@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{ user1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
user2 := mustCreateUser(s.T(), s.db, &model.User{ user2 := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
_ = mustCreateUser(s.T(), s.db, &model.User{ _ = mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
got, err := s.repo.GetByID(s.ctx, user1.ID) got, err := s.repo.GetByID(s.ctx, user1.ID)
...@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch") s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10} params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
......
...@@ -4,111 +4,113 @@ import ( ...@@ -4,111 +4,113 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库
type userSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists) return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
Preload("Group"). Preload("Group").
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// Update 更新订阅 func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return err
} }
// Delete 删除订阅
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ?", userID). Where("user_id = ?", userID).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListActiveByUserID 获取用户的所有有效订阅 func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()). userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListByGroupID 获取分组的所有订阅(分页) func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
...@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// List 获取所有订阅(分页,支持筛选) func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
if userID != nil { if userID != nil {
query = query.Where("user_id = ?", *userID) query = query.Where("user_id = ?", *userID)
} }
...@@ -170,156 +160,240 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -170,156 +160,240 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 { }
pages++
}
return subs, &pagination.PaginationResult{ func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
Total: total, var count int64
Page: params.Page, err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
PageSize: params.Limit(), Where("user_id = ? AND group_id = ?", userID, groupID).
Pages: pages, Count(&count).Error
}, nil return count > 0, err
} }
// IncrementUsage 增加使用量 func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD), "expires_at": newExpiresAt,
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD), "updated_at": time.Now(),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetDailyUsage 重置日使用量 func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_usage_usd": 0, "status": status,
"daily_window_start": newWindowStart, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetWeeklyUsage 重置周使用量 func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", subscriptionID).
Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"weekly_usage_usd": 0, "notes": notes,
"weekly_window_start": newWindowStart, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// ResetMonthlyUsage 重置月使用量 func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"monthly_usage_usd": 0, "daily_window_start": start,
"monthly_window_start": newWindowStart, "weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(), "updated_at": time.Now(),
}).Error }).Error
} }
// ActivateWindows 激活所有窗口(首次使用时) func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_window_start": activateTime, "daily_usage_usd": 0,
"weekly_window_start": activateTime, "daily_window_start": newWindowStart,
"monthly_window_start": activateTime, "updated_at": time.Now(),
"updated_at": time.Now(),
}).Error }).Error
} }
// UpdateStatus 更新订阅状态 func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": status, "weekly_usage_usd": 0,
"updated_at": time.Now(), "weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error }).Error
} }
// ExtendExpiry 延长订阅过期时间 func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"expires_at": newExpiresAt, "monthly_usage_usd": 0,
"updated_at": time.Now(), "monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error }).Error
} }
// UpdateNotes 更新订阅备注 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"notes": notes, "daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"updated_at": time.Now(), "weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error }).Error
} }
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
"status": model.SubscriptionStatusExpired, "status": service.SubscriptionStatusExpired,
"updated_at": time.Now(), "updated_at": time.Now(),
}) })
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). var subs []userSubscriptionModel
Where("user_id = ? AND group_id = ?", userID, groupID). err := r.db.WithContext(ctx).
Count(&count).Error Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
return count > 0, err Find(&subs).Error
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// CountByGroupID 获取分组的订阅数量
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()). groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
StartsAt: m.StartsAt,
ExpiresAt: m.ExpiresAt,
Status: m.Status,
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
}
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) { ...@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() { func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
sub := &model.UserSubscription{ sub := &service.UserSubscription{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
} }
...@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() { ...@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
} }
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() { func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID, AssignedBy: &admin.ID,
}) })
...@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() { ...@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdate() { func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) }))
sub.Notes = "updated notes" sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub) err := s.repo.Update(s.ctx, sub)
...@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() { ...@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
} }
func (s *UserSubscriptionRepoSuite) TestDelete() { func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() { ...@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID --- // --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() { ...@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
// Create active subscription (future expiry) // Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
...@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { ...@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
// Create expired subscription (past expiry but active status) // Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
...@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor ...@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID --- // --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() { func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() { ...@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() { func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID) subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID") s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status) s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
} }
// --- ListByGroupID --- // --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() { func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() { ...@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters --- // --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { ...@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { ...@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { ...@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired) subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
} }
// --- Usage tracking --- // --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { ...@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { ...@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
} }
func (s *UserSubscriptionRepoSuite) TestActivateWindows() { func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() { ...@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
} }
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0, DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0, WeeklyUsageUSD: 20.0,
...@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { ...@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0, WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0, MonthlyUsageUSD: 30.0,
...@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { ...@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0, MonthlyUsageUSD: 100.0,
}) })
...@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { ...@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
// --- UpdateStatus / ExtendExpiry / UpdateNotes --- // --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() { func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired) err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus") s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID) got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status) s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
} }
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { ...@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { ...@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus --- // --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() { func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() { ...@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
} }
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { ...@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
s.Require().Equal(int64(1), affected) s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID) gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status) s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID) gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status) s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
} }
// --- ExistsByUserIDAndGroupID --- // --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
...@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { ...@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID --- // --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { ...@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
}) })
...@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { ...@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID --- // --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
...@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { ...@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
...@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba ...@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(int64(1), affected, "expected 1 affected row") s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID) updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired") s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }
//go:build unit
package server_test
import (
"bytes"
"context"
"errors"
"io"
"math"
"net/http"
"net/http/httptest"
"sort"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
adminhandler "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAPIContracts(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
setup func(t *testing.T, deps *contractDeps)
method string
path string
body string
headers map[string]string
wantStatus int
wantJSON string
}{
{
name: "GET /api/v1/auth/me",
method: http.MethodGet,
path: "/api/v1/auth/me",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"wechat": "wx_alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
}`,
},
{
name: "POST /api/v1/keys",
method: http.MethodPost,
path: "/api/v1/keys",
body: `{"name":"Key One","custom_key":"sk_custom_1234567890"}`,
headers: map[string]string{
"Content-Type": "application/json",
},
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"id": 100,
"user_id": 1,
"key": "sk_custom_1234567890",
"name": "Key One",
"group_id": null,
"status": "active",
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
}`,
},
{
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
Name: "Key One",
Status: service.StatusActive,
CreatedAt: deps.now,
UpdatedAt: deps.now,
})
},
method: http.MethodGet,
path: "/api/v1/keys?page=1&page_size=10",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"items": [
{
"id": 100,
"user_id": 1,
"key": "sk_custom_1234567890",
"name": "Key One",
"group_id": null,
"status": "active",
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
],
"total": 1,
"page": 1,
"page_size": 10,
"pages": 1
}
}`,
},
{
name: "GET /api/v1/usage/stats",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
DurationMs: ptr(100),
CreatedAt: deps.now,
},
{
ID: 2,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 15,
TotalCost: 0.25,
ActualCost: 0.25,
DurationMs: ptr(300),
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/usage/stats?start_date=2025-01-01&end_date=2025-01-02",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"total_requests": 2,
"total_input_tokens": 15,
"total_output_tokens": 35,
"total_cache_tokens": 3,
"total_tokens": 53,
"total_cost": 0.75,
"total_actual_cost": 0.75,
"average_duration_ms": 200
}
}`,
},
{
name: "GET /api/v1/usage (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/usage?page=1&page_size=10",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"items": [
{
"id": 1,
"user_id": 1,
"api_key_id": 100,
"account_id": 200,
"request_id": "req_123",
"model": "claude-3",
"group_id": null,
"subscription_id": null,
"input_tokens": 10,
"output_tokens": 20,
"cache_creation_tokens": 1,
"cache_read_tokens": 2,
"cache_creation_5m_tokens": 0,
"cache_creation_1h_tokens": 0,
"input_cost": 0,
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
"created_at": "2025-01-02T03:04:05Z"
}
],
"total": 1,
"page": 1,
"page_size": 10,
"pages": 1
}
}`,
},
{
name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key",
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password": "secret",
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
deps := newContractDeps(t)
if tt.setup != nil {
tt.setup(t, deps)
}
status, body := doRequest(t, deps.router, tt.method, tt.path, tt.body, tt.headers)
require.Equal(t, tt.wantStatus, status)
require.JSONEq(t, tt.wantJSON, body)
})
}
}
type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubApiKeyRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
}
func newContractDeps(t *testing.T) *contractDeps {
t.Helper()
now := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)
userRepo := &stubUserRepo{
users: map[int64]*service.User{
1: {
ID: 1,
Email: "alice@example.com",
Username: "alice",
Wechat: "wx_alice",
Notes: "hello",
Role: service.RoleUser,
Balance: 12.5,
Concurrency: 5,
Status: service.StatusActive,
AllowedGroups: nil,
CreatedAt: now,
UpdatedAt: now,
},
},
}
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
ApiKeyPrefix: "sk-",
},
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(nil, userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 5,
})
c.Set(string(middleware.ContextKeyUserRole), service.RoleUser)
c.Next()
}
adminAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 5,
})
c.Set(string(middleware.ContextKeyUserRole), service.RoleAdmin)
c.Next()
}
r := gin.New()
v1 := r.Group("/api/v1")
v1Auth := v1.Group("")
v1Auth.Use(jwtAuth)
v1Auth.GET("/auth/me", authHandler.GetCurrentUser)
v1Keys := v1.Group("")
v1Keys.Use(jwtAuth)
v1Keys.GET("/keys", apiKeyHandler.List)
v1Keys.POST("/keys", apiKeyHandler.Create)
v1Usage := v1.Group("")
v1Usage.Use(jwtAuth)
v1Usage.GET("/usage", usageHandler.List)
v1Usage.GET("/usage/stats", usageHandler.Stats)
v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
return &contractDeps{
now: now,
router: r,
apiKeyRepo: apiKeyRepo,
usageRepo: usageRepo,
settingRepo: settingRepo,
}
}
func doRequest(t *testing.T, router http.Handler, method, path, body string, headers map[string]string) (int, string) {
t.Helper()
req := httptest.NewRequest(method, path, bytes.NewBufferString(body))
for k, v := range headers {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
respBody, err := io.ReadAll(w.Result().Body)
require.NoError(t, err)
return w.Result().StatusCode, string(respBody)
}
func ptr[T any](v T) *T { return &v }
type stubUserRepo struct {
users map[int64]*service.User
}
func (r *stubUserRepo) Create(ctx context.Context, user *service.User) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
user, ok := r.users[id]
if !ok {
return nil, service.ErrUserNotFound
}
clone := *user
return &clone, nil
}
func (r *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
for _, user := range r.users {
if user.Email == email {
clone := *user
return &clone, nil
}
}
return nil, service.ErrUserNotFound
}
func (r *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
for _, user := range r.users {
if user.Role == service.RoleAdmin && user.Status == service.StatusActive {
clone := *user
return &clone, nil
}
}
return nil, service.ErrUserNotFound
}
func (r *stubUserRepo) Update(ctx context.Context, user *service.User) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}
func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}
func (stubGroupRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, errors.New("not implemented")
}
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, errors.New("not implemented")
}
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct{}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
type stubApiKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
}
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
}
}
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
if key == nil {
return
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
if key.ID == 0 {
key.ID = r.nextID
r.nextID++
}
if key.CreatedAt.IsZero() {
key.CreatedAt = r.now
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
return nil
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrApiKeyNotFound
}
clone := *key
return &clone, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrApiKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
}
clone := *key
r.byID[clone.ID] = &clone
r.byKey[clone.Key] = &clone
return nil
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrApiKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
ids = append(ids, id)
}
}
sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] })
start := params.Offset()
if start > len(ids) {
start = len(ids)
}
end := start + params.Limit()
if end > len(ids) {
end = len(ids)
}
out := make([]service.ApiKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
}
total := int64(len(ids))
pageSize := params.Limit()
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
return out, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: pageSize,
Pages: pages,
}, nil
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
for _, key := range r.byID {
if key.UserID == userID {
count++
}
}
return count, nil
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key]
return ok, nil
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
func newStubUsageLogRepo() *stubUsageLogRepo {
return &stubUsageLogRepo{userLogs: make(map[int64][]service.UsageLog)}
}
func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
r.userLogs[userID] = logs
}
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error {
return errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
logs := r.userLogs[userID]
total := int64(len(logs))
out := paginateLogs(logs, params)
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
logs := r.userLogs[userID]
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string
}
func newStubSettingRepo() *stubSettingRepo {
return &stubSettingRepo{all: make(map[string]string)}
}
func (r *stubSettingRepo) SetAll(values map[string]string) {
r.all = make(map[string]string, len(values))
for k, v := range values {
r.all[k] = v
}
}
func (r *stubSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
value, ok := r.all[key]
if !ok {
return nil, service.ErrSettingNotFound
}
return &service.Setting{Key: key, Value: value}, nil
}
func (r *stubSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
value, ok := r.all[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (r *stubSettingRepo) Set(ctx context.Context, key, value string) error {
r.all[key] = value
return nil
}
func (r *stubSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = r.all[key]
}
return out, nil
}
func (r *stubSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
r.all[k] = v
}
return nil
}
func (r *stubSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.all))
for k, v := range r.all {
out[k] = v
}
return out, nil
}
func (r *stubSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.all, key)
return nil
}
func paginateLogs(logs []service.UsageLog, params pagination.PaginationParams) []service.UsageLog {
start := params.Offset()
if start > len(logs) {
start = len(logs)
}
end := start + params.Limit()
if end > len(logs) {
end = len(logs)
}
out := make([]service.UsageLog, 0, end-start)
out = append(out, logs[start:end]...)
return out
}
func paginationResult(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pageSize := params.Limit()
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: pageSize,
Pages: pages,
}
}
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
_ service.SettingRepository = (*stubSettingRepo)(nil)
)
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -84,7 +83,11 @@ func validateAdminApiKey( ...@@ -84,7 +83,11 @@ func validateAdminApiKey(
return false return false
} }
c.Set(string(ContextKeyUser), admin) c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key") c.Set("auth_method", "admin_api_key")
return true return true
} }
...@@ -121,12 +124,16 @@ func validateJWTForAdmin( ...@@ -121,12 +124,16 @@ func validateJWTForAdmin(
} }
// 检查管理员权限 // 检查管理员权限
if user.Role != model.RoleAdmin { if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false return false
} }
c.Set(string(ContextKeyUser), user) c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt") c.Set("auth_method", "jwt")
return true return true
......
package middleware package middleware
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -10,15 +10,14 @@ import ( ...@@ -10,15 +10,14 @@ import (
// 必须在JWTAuth中间件之后使用 // 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc { func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 从上下文获取用户 role, ok := GetUserRoleFromContext(c)
user, exists := GetUserFromContext(c) if !ok {
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context") AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return return
} }
// 检查是否为管理员 // 检查是否为管理员
if user.Role != model.RoleAdmin { if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return return
} }
......
...@@ -5,11 +5,9 @@ import ( ...@@ -5,11 +5,9 @@ import (
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
...@@ -46,7 +44,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -46,7 +44,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
...@@ -121,28 +119,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -121,28 +119,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User) c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next() c.Next()
} }
} }
// GetApiKeyFromContext 从上下文中获取API key // GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) { func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey)) value, exists := c.Get(string(ContextKeyApiKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*model.ApiKey) apiKey, ok := value.(*service.ApiKey)
return apiKey, ok return apiKey, ok
} }
// GetSubscriptionFromContext 从上下文中获取订阅信息 // GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) { func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription)) value, exists := c.Get(string(ContextKeySubscription))
if !exists { if !exists {
return nil, false return nil, false
} }
subscription, ok := value.(*model.UserSubscription) subscription, ok := value.(*service.UserSubscription)
return subscription, ok return subscription, ok
} }
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) ...@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
return return
} }
// 将用户信息存入上下文 c.Set(string(ContextKeyUser), AuthSubject{
c.Set(string(ContextKeyUser), user) UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next() c.Next()
} }
} }
// GetUserFromContext 从上下文中获取用户 // Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}
...@@ -8,6 +8,8 @@ type ContextKey string ...@@ -8,6 +8,8 @@ type ContextKey string
const ( const (
// ContextKeyUser 用户上下文键 // ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
......
package model package service
import ( import "time"
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm" type Account struct {
) ID int64
Name string
// JSONB 用于存储JSONB数据 Platform string
type JSONB map[string]any Type string
Credentials map[string]any
func (j JSONB) Value() (driver.Value, error) { Extra map[string]any
if j == nil { ProxyID *int64
return nil, nil Concurrency int
} Priority int
return json.Marshal(j) Status string
} ErrorMessage string
LastUsedAt *time.Time
func (j *JSONB) Scan(value any) error { CreatedAt time.Time
if value == nil { UpdatedAt time.Time
*j = nil
return nil Schedulable bool
}
bytes, ok := value.([]byte) RateLimitedAt *time.Time
if !ok { RateLimitResetAt *time.Time
return errors.New("type assertion to []byte failed") OverloadUntil *time.Time
}
return json.Unmarshal(bytes, j) SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
Proxy *Proxy
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
} }
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool { func (a *Account) IsActive() bool {
return a.Status == "active" return a.Status == StatusActive
} }
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool { func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable { if !a.IsActive() || !a.Schedulable {
return false return false
...@@ -96,7 +52,6 @@ func (a *Account) IsSchedulable() bool { ...@@ -96,7 +52,6 @@ func (a *Account) IsSchedulable() bool {
return true return true
} }
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool { func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil { if a.RateLimitResetAt == nil {
return false return false
...@@ -104,7 +59,6 @@ func (a *Account) IsRateLimited() bool { ...@@ -104,7 +59,6 @@ func (a *Account) IsRateLimited() bool {
return time.Now().Before(*a.RateLimitResetAt) return time.Now().Before(*a.RateLimitResetAt)
} }
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool { func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil { if a.OverloadUntil == nil {
return false return false
...@@ -112,17 +66,14 @@ func (a *Account) IsOverloaded() bool { ...@@ -112,17 +66,14 @@ func (a *Account) IsOverloaded() bool {
return time.Now().Before(*a.OverloadUntil) return time.Now().Before(*a.OverloadUntil)
} }
// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token)
func (a *Account) IsOAuth() bool { func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
} }
// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限)
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string { func (a *Account) GetCredential(key string) string {
if a.Credentials == nil { if a.Credentials == nil {
return "" return ""
...@@ -135,8 +86,6 @@ func (a *Account) GetCredential(key string) string { ...@@ -135,8 +86,6 @@ func (a *Account) GetCredential(key string) string {
return "" return ""
} }
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -145,7 +94,6 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -145,7 +94,6 @@ func (a *Account) GetModelMapping() map[string]string {
if !ok || raw == nil { if !ok || raw == nil {
return nil return nil
} }
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]any); ok { if m, ok := raw.(map[string]any); ok {
result := make(map[string]string) result := make(map[string]string)
for k, v := range m { for k, v := range m {
...@@ -160,19 +108,15 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -160,19 +108,15 @@ func (a *Account) GetModelMapping() map[string]string {
return nil return nil
} }
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型 return true
} }
_, exists := mapping[requestedModel] _, exists := mapping[requestedModel]
return exists return exists
} }
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
...@@ -184,19 +128,17 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -184,19 +128,17 @@ func (a *Account) GetMappedModel(requestedModel string) string {
return requestedModel return requestedModel
} }
// GetBaseURL 获取API基础URL(用于apikey类型账号)
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey { if a.Type != AccountTypeApiKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL == "" { if baseURL == "" {
return "https://api.anthropic.com" // 默认URL return "https://api.anthropic.com"
} }
return baseURL return baseURL
} }
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string { func (a *Account) GetExtraString(key string) string {
if a.Extra == nil { if a.Extra == nil {
return "" return ""
...@@ -209,7 +151,6 @@ func (a *Account) GetExtraString(key string) string { ...@@ -209,7 +151,6 @@ func (a *Account) GetExtraString(key string) string {
return "" return ""
} }
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil { if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false return false
...@@ -222,7 +163,6 @@ func (a *Account) IsCustomErrorCodesEnabled() bool { ...@@ -222,7 +163,6 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
return false return false
} }
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int { func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -231,11 +171,9 @@ func (a *Account) GetCustomErrorCodes() []int { ...@@ -231,11 +171,9 @@ func (a *Account) GetCustomErrorCodes() []int {
if !ok || raw == nil { if !ok || raw == nil {
return nil return nil
} }
// 处理 []interface{} 类型(JSON反序列化后的格式)
if arr, ok := raw.([]any); ok { if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr)) result := make([]int, 0, len(arr))
for _, v := range arr { for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok { if f, ok := v.(float64); ok {
result = append(result, int(f)) result = append(result, int(f))
} }
...@@ -245,18 +183,14 @@ func (a *Account) GetCustomErrorCodes() []int { ...@@ -245,18 +183,14 @@ func (a *Account) GetCustomErrorCodes() []int {
return nil return nil
} }
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略)
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool { func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() { if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略 return true
} }
codes := a.GetCustomErrorCodes() codes := a.GetCustomErrorCodes()
if len(codes) == 0 { if len(codes) == 0 {
return true // 启用但列表为空,fallback到默认策略 return true
} }
// 检查是否在自定义列表中
for _, code := range codes { for _, code := range codes {
if code == statusCode { if code == statusCode {
return true return true
...@@ -265,8 +199,6 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool { ...@@ -265,8 +199,6 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
return false return false
} }
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool { func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil { if a.Credentials == nil {
return false return false
...@@ -279,31 +211,22 @@ func (a *Account) IsInterceptWarmupEnabled() bool { ...@@ -279,31 +211,22 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false return false
} }
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool { func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI return a.Platform == PlatformOpenAI
} }
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool { func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic return a.Platform == PlatformAnthropic
} }
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool { func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth return a.IsOpenAI() && a.Type == AccountTypeOAuth
} }
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
func (a *Account) IsOpenAIApiKey() bool { func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey return a.IsOpenAI() && a.Type == AccountTypeApiKey
} }
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -314,10 +237,9 @@ func (a *Account) GetOpenAIBaseURL() string { ...@@ -314,10 +237,9 @@ func (a *Account) GetOpenAIBaseURL() string {
return baseURL return baseURL
} }
} }
return "https://api.openai.com" // OpenAI 默认 API URL return "https://api.openai.com"
} }
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string { func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -325,7 +247,6 @@ func (a *Account) GetOpenAIAccessToken() string { ...@@ -325,7 +247,6 @@ func (a *Account) GetOpenAIAccessToken() string {
return a.GetCredential("access_token") return a.GetCredential("access_token")
} }
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string { func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -333,7 +254,6 @@ func (a *Account) GetOpenAIRefreshToken() string { ...@@ -333,7 +254,6 @@ func (a *Account) GetOpenAIRefreshToken() string {
return a.GetCredential("refresh_token") return a.GetCredential("refresh_token")
} }
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
func (a *Account) GetOpenAIIDToken() string { func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -341,7 +261,6 @@ func (a *Account) GetOpenAIIDToken() string { ...@@ -341,7 +261,6 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token") return a.GetCredential("id_token")
} }
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string { func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() { if !a.IsOpenAIApiKey() {
return "" return ""
...@@ -349,8 +268,6 @@ func (a *Account) GetOpenAIApiKey() string { ...@@ -349,8 +268,6 @@ func (a *Account) GetOpenAIApiKey() string {
return a.GetCredential("api_key") return a.GetCredential("api_key")
} }
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string { func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
...@@ -358,7 +275,6 @@ func (a *Account) GetOpenAIUserAgent() string { ...@@ -358,7 +275,6 @@ func (a *Account) GetOpenAIUserAgent() string {
return a.GetCredential("user_agent") return a.GetCredential("user_agent")
} }
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string { func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -366,7 +282,6 @@ func (a *Account) GetChatGPTAccountID() string { ...@@ -366,7 +282,6 @@ func (a *Account) GetChatGPTAccountID() string {
return a.GetCredential("chatgpt_account_id") return a.GetCredential("chatgpt_account_id")
} }
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string { func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -374,7 +289,6 @@ func (a *Account) GetChatGPTUserID() string { ...@@ -374,7 +289,6 @@ func (a *Account) GetChatGPTUserID() string {
return a.GetCredential("chatgpt_user_id") return a.GetCredential("chatgpt_user_id")
} }
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string { func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return "" return ""
...@@ -382,7 +296,6 @@ func (a *Account) GetOpenAIOrganizationID() string { ...@@ -382,7 +296,6 @@ func (a *Account) GetOpenAIOrganizationID() string {
return a.GetCredential("organization_id") return a.GetCredential("organization_id")
} }
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time { func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return nil return nil
...@@ -391,25 +304,21 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time { ...@@ -391,25 +304,21 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if expiresAtStr == "" { if expiresAtStr == "" {
return nil return nil
} }
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr) t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil { if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok { if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0) tt := time.Unix(int64(v), 0)
return &t return &tt
} }
return nil return nil
} }
return &t return &t
} }
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool { func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt() expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil { if expiresAt == nil {
return false // 没有过期时间信息,假设未过期 return false
} }
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt) return time.Now().Add(60 * time.Second).After(*expiresAt)
} }
package service
import "time"
type AccountGroup struct {
AccountID int64
GroupID int64
Priority int
CreatedAt time.Time
Account *Account
Group *Group
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment