Commit 25a304c2 authored by Forest's avatar Forest
Browse files

test: 增加 repository 测试

parent 9d30ceae
//go:build integration
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
gormpostgres "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
redisImageTag = "redis:8.4-alpine"
postgresImageTag = "postgres:18.1-alpine3.23"
)
var (
integrationDB *gorm.DB
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
func TestMain(m *testing.M) {
ctx := context.Background()
if err := timezone.Init("UTC"); err != nil {
log.Printf("failed to init timezone: %v", err)
os.Exit(1)
}
if !dockerIsAvailable(ctx) {
// In CI we expect Docker to be available so integration tests should fail loudly.
if os.Getenv("CI") != "" {
log.Printf("docker is not available (CI=true); failing integration tests")
os.Exit(1)
}
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
os.Exit(0)
}
postgresImage := selectDockerImage(ctx, postgresImageTag)
pgContainer, err := tcpostgres.Run(
ctx,
postgresImage,
tcpostgres.WithDatabase("sub2api_test"),
tcpostgres.WithUsername("postgres"),
tcpostgres.WithPassword("postgres"),
tcpostgres.BasicWaitStrategies(),
)
if err != nil {
log.Printf("failed to start postgres container: %v", err)
os.Exit(1)
}
defer func() { _ = pgContainer.Terminate(ctx) }()
redisContainer, err := tcredis.Run(
ctx,
redisImageTag,
)
if err != nil {
log.Printf("failed to start redis container: %v", err)
os.Exit(1)
}
defer func() { _ = redisContainer.Terminate(ctx) }()
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
if err != nil {
log.Printf("failed to get postgres dsn: %v", err)
os.Exit(1)
}
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
os.Exit(1)
}
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
if err != nil {
log.Printf("failed to get redis port: %v", err)
os.Exit(1)
}
integrationRedis = redisclient.NewClient(&redisclient.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
if err := integrationRedis.Ping(ctx).Err(); err != nil {
log.Printf("failed to ping redis: %v", err)
os.Exit(1)
}
code := m.Run()
_ = integrationRedis.Close()
os.Exit(code)
}
func dockerIsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "docker", "info")
cmd.Env = os.Environ()
return cmd.Run() == nil
}
func selectDockerImage(ctx context.Context, preferred string) string {
if dockerImageExists(ctx, preferred) {
return preferred
}
return preferred
}
func dockerImageExists(ctx context.Context, image string) bool {
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
cmd.Env = os.Environ()
cmd.Stdout = nil
cmd.Stderr = nil
return cmd.Run() == nil
}
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
sqlDB, err := db.DB()
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
return db, nil
}
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *gorm.DB {
t.Helper()
tx := integrationDB.Begin()
require.NoError(t, tx.Error, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback().Error
})
return tx
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
prefix := fmt.Sprintf(
"it:%s:%d:%d:",
sanitizeRedisNamespace(t.Name()),
time.Now().UnixNano(),
atomic.AddUint64(&redisNamespaceSeq, 1),
)
opts := *integrationRedis.Options()
rdb := redisclient.NewClient(&opts)
rdb.AddHook(prefixHook{prefix: prefix})
t.Cleanup(func() {
ctx := context.Background()
var cursor uint64
for {
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
require.NoError(t, err, "scan redis keys for cleanup")
if len(keys) > 0 {
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
}
cursor = nextCursor
if cursor == 0 {
break
}
}
_ = rdb.Close()
})
return rdb
}
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
t.Helper()
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
func sanitizeRedisNamespace(name string) string {
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, " ", "_")
return name
}
type prefixHook struct {
prefix string
}
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
return func(ctx context.Context, cmd redisclient.Cmder) error {
h.prefixCmd(cmd)
return next(ctx, cmd)
}
}
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
return func(ctx context.Context, cmds []redisclient.Cmder) error {
for _, cmd := range cmds {
h.prefixCmd(cmd)
}
return next(ctx, cmds)
}
}
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
args := cmd.Args()
if len(args) < 2 {
return
}
prefixOne := func(i int) {
if i < 0 || i >= len(args) {
return
}
switch v := args[i].(type) {
case string:
if v != "" && !strings.HasPrefix(v, h.prefix) {
args[i] = h.prefix + v
}
case []byte:
s := string(v)
if s != "" && !strings.HasPrefix(s, h.prefix) {
args[i] = []byte(h.prefix + s)
}
}
}
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "eval", "evalsha", "eval_ro", "evalsha_ro":
if len(args) < 3 {
return
}
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
if err != nil || numKeys <= 0 {
return
}
for i := 0; i < numKeys && 3+i < len(args); i++ {
prefixOne(3 + i)
}
case "scan":
for i := 2; i+1 < len(args); i++ {
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
prefixOne(i + 1)
break
}
}
}
}
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
suite.Suite
ctx context.Context
rdb *redisclient.Client
}
// SetupTest initializes ctx and rdb for each test method.
func (s *IntegrationRedisSuite) SetupTest() {
s.ctx = context.Background()
s.rdb = testRedis(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
// AssertTTLWithin asserts that ttl is within [min, max].
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
s.T().Helper()
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
// Embedding suites should call SetupTest to initialize ctx and db.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
}
// SetupTest initializes ctx and db for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
......@@ -12,11 +12,13 @@ import (
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
return &openaiOAuthService{tokenURL: openai.TokenURL}
}
type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
......@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
......@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
......
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type OpenAIOAuthServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
svc *openaiOAuthService
received chan url.Values
}
func (s *OpenAIOAuthServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
}
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
errCh <- "method mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code"); got != "code" {
errCh <- "code mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code_verifier"); got != "ver" {
errCh <- "code_verifier mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at", resp.AccessToken)
require.Equal(s.T(), "rt", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("refresh_token"); got != "rt" {
errCh <- "refresh_token mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
errCh <- "scope mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at2", resp.AccessToken)
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
}
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
want := "http://localhost:9999/cb"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("redirect_uri"); got != want {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
s.received <- r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "unauthorized")
}))
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.Error(s.T(), err, "expected error for non-2xx status")
require.ErrorContains(s.T(), err, "status 401")
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}
package repository
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type PricingServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
client *pricingRemoteClient
}
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
func (s *PricingServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ok" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
require.NoError(s.T(), err, "FetchPricingJSON")
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/hashfile":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
case "/hashonly":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("def456\n"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "abc123", hash, "hash mismatch")
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "def456", hash2, "hash mismatch")
}
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// empty body
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
require.NoError(s.T(), err, "FetchHashText empty body should not error")
require.Equal(s.T(), "", hash, "expected empty hash")
}
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(" \n"))
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}
......@@ -16,10 +16,14 @@ import (
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
......@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
......
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ProxyProbeServiceSuite struct {
suite.Suite
ctx context.Context
proxySrv *httptest.Server
prober *proxyProbeService
}
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
if s.proxySrv != nil {
s.proxySrv.Close()
s.proxySrv = nil
}
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
_, err := createProxyTransport("://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.proxySrv.Close()
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ProxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db)
}
func TestProxyRepoSuite(t *testing.T) {
suite.Run(t, new(ProxyRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
s.Require().NoError(err, "Create")
s.Require().NotZero(proxy.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, proxy.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(proxies, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal("socks5", proxies[0].Protocol)
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Contains(proxies[0].Name, "production")
}
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(proxies, 1)
s.Require().Equal("active1", proxies[0].Name)
}
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists)
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
s.Require().NoError(err)
s.Require().False(notExists)
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
s.Require().NoError(err)
s.Require().True(exists)
}
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count)
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
}
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err)
s.Require().Empty(counts)
}
// --- ListActiveWithAccountCount ---
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Status: model.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Status: model.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p3-inactive",
Status: model.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 active proxies")
// Sorted by created_at DESC, so p2 first
s.Require().Equal(p2.ID, withCounts[0].ID)
s.Require().Equal(int64(1), withCounts[0].AccountCount)
s.Require().Equal(p1.ID, withCounts[1].ID)
s.Require().Equal(int64(2), withCounts[1].AccountCount)
}
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "u",
Password: "p",
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 proxies")
for _, pc := range withCounts {
switch pc.ID {
case p1.ID:
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
case p2.ID:
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
default:
s.Require().Fail("unexpected proxy id", pc.ID)
}
}
}
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type RedeemCacheSuite struct {
IntegrationRedisSuite
cache *redeemCache
}
func (s *RedeemCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
}
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
userID := int64(1)
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err, "GetRedeemAttemptCount")
require.Equal(s.T(), 1, count, "count mismatch")
ttl, err := s.rdb.TTL(s.ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
}
func (s *RedeemCacheSuite) TestMultipleIncrements() {
userID := int64(2)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, count, "count after 3 increments")
}
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock")
require.True(s.T(), ok)
// Second acquire should fail
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock 2")
require.False(s.T(), ok, "expected lock to be held")
// Release
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
// Now acquire should succeed
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock after release")
require.True(s.T(), ok)
}
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
lockKey := redeemLockKeyPrefix + "CODE2"
lockTTL := 15 * time.Second
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
require.NoError(s.T(), err, "TTL lock key")
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
}
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
// Release a lock that doesn't exist should not error
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
// Acquire, release, release again
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
}
func TestRedeemCacheSuite(t *testing.T) {
suite.Run(t, new(RedeemCacheSuite))
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *RedeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
suite.Run(t, new(RedeemCodeRepoSuite))
}
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{
Code: "TEST-CREATE",
Type: model.RedeemTypeBalance,
Value: 100,
Status: model.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
s.Require().NoError(err, "Create")
s.Require().NotZero(code.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
s.Require().NoError(err, "CreateBatch")
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
s.Require().NoError(err)
s.Require().Equal(float64(10), got1.Value)
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
s.Require().NoError(err)
s.Require().Equal(float64(20), got2.Value)
}
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
s.Require().Equal("GET-BY-CODE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
s.Require().Error(err, "expected error for non-existent code")
}
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, code.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(codes, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Contains(codes[0].Code, "ALPHA")
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription,
GroupID: &group.ID,
})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group, "expected Group preload")
s.Require().Equal(group.ID, codes[0].Group.ID)
}
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
code.Value = 50
err := s.repo.Update(s.ctx, code)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(float64(50), got.Value)
}
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
// Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-1",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-2",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(codes, 2)
// Ordered by used_at DESC, so USER-2 first
s.Require().Equal("USER-2", codes[0].Code)
s.Require().Equal("USER-1", codes[1].Code)
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GRP",
Type: model.RedeemTypeSubscription,
Status: model.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
s.db.Model(c).Update("used_at", time.Now())
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group)
s.Require().Equal(group.ID, codes[0].Group.ID)
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "DEF-LIM",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
// limit <= 0 should default to 10
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
s.Require().NoError(err)
s.Require().Len(codes, 1)
}
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
codes := []model.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
s.Require().NotNil(list[0].Group, "expected Group preload")
s.Require().Equal(group.ID, list[0].Group.ID)
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
s.Require().NoError(err, "GetByCode")
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(used, 2, "expected 2 used codes")
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
}
//go:build integration
package repository
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type SettingRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *SettingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db)
}
func TestSettingRepoSuite(t *testing.T) {
suite.Run(t, new(SettingRepoSuite))
}
func (s *SettingRepoSuite) TestSetAndGetValue() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue")
s.Require().Equal("v1", got, "GetValue mismatch")
}
func (s *SettingRepoSuite) TestSet_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue after upsert")
s.Require().Equal("v2", got, "upsert mismatch")
}
func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
s.Require().NoError(err, "GetMultiple")
s.Require().Equal("v2", m["k2"])
s.Require().Equal("v3", m["k3"])
}
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
m, err := s.repo.GetMultiple(s.ctx, []string{})
s.Require().NoError(err, "GetMultiple with empty keys")
s.Require().Empty(m, "expected empty map")
}
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
s.Require().NoError(err, "GetMultiple subset")
s.Require().Equal("1", m["a"])
s.Require().Equal("3", m["c"])
_, exists := m["nonexistent"]
s.Require().False(exists, "nonexistent key should not be in map")
}
func (s *SettingRepoSuite) TestGetAll() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
all, err := s.repo.GetAll(s.ctx)
s.Require().NoError(err, "GetAll")
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
s.Require().Equal("1", all["x"])
s.Require().Equal("2", all["y"])
}
func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestDelete_Idempotent() {
// Delete a key that doesn't exist should not error
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
}
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
got, err := s.repo.GetValue(s.ctx, "upsert_key")
s.Require().NoError(err)
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
got2, err := s.repo.GetValue(s.ctx, "new_key")
s.Require().NoError(err)
s.Require().Equal("new_val", got2)
}
......@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
type turnstileVerifier struct {
httpClient *http.Client
verifyURL string
}
func NewTurnstileVerifier() service.TurnstileVerifier {
......@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifyURL: turnstileVerifyURL,
}
}
......@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
......
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
func (s *TurnstileServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
require.True(s.T(), ok, "type assertion failed")
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken")
require.NotNil(s.T(), resp)
require.True(s.T(), resp.Success, "expected success response")
// Assert form fields in main goroutine
select {
case values := <-s.received:
require.Equal(s.T(), "sk", values.Get("secret"))
require.Equal(s.T(), "token", values.Get("response"))
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err)
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
require.NoError(s.T(), err)
select {
case values := <-s.received:
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,
ErrorCodes: []string{"invalid-input-response"},
})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
require.NotNil(s.T(), resp)
require.False(s.T(), resp.Success)
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
}
func TestTurnstileServiceSuite(t *testing.T) {
suite.Run(t, new(TurnstileServiceSuite))
}
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type UpdateCacheSuite struct {
IntegrationRedisSuite
cache *updateCache
}
func (s *UpdateCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
_, err := s.cache.GetUpdateInfo(s.ctx)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err, "GetUpdateInfo")
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err, "TTL updateCacheKey")
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
// TTL=0 means persist forever (no expiry) in Redis SET command
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v0.0.0", info)
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err)
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
func TestUpdateCacheSuite(t *testing.T) {
suite.Run(t, new(UpdateCacheSuite))
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UsageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db)
}
func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
TotalCost: cost,
ActualCost: cost,
CreatedAt: createdAt,
}
s.Require().NoError(s.repo.Create(s.ctx, log))
return log
}
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.4,
}
err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(log.ID, got.ID)
s.Require().Equal(10, got.InputTokens)
}
func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
err := s.repo.Delete(s.ctx, log.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, log.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUser")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAccount")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.InputTokens)
s.Require().Equal(int64(45), stats.OutputTokens)
}
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{UserID: user.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetDashboardStats ---
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &model.User{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 3,
CacheReadTokens: 4,
TotalCost: 1.5,
ActualCost: 1.2,
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 6,
TotalCost: 0.7,
ActualCost: 0.7,
DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
OutputTokens: 2,
TotalCost: 0.1,
ActualCost: 0.1,
DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second),
}
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
}
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
}
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
s.Require().NotNil(stats[user2.ID])
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour))
s.Require().NoError(err, "GetGlobalStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.TotalInputTokens)
s.Require().Equal(int64(45), stats.TotalOutputTokens)
}
func maxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "ListByUserAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "ListByAccountAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
OutputTokens: 25,
TotalCost: 0.6,
ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
OutputTokens: 30,
TotalCost: 0.7,
ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log3))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime)
s.Require().NoError(err, "ListByModelAndTimeRange")
s.Require().Len(logs, 2)
}
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute))
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window
stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart)
s.Require().NoError(err, "GetAccountWindowStats")
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25)
}
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day")
s.Require().NoError(err, "GetUserUsageTrendByUserID")
s.Require().Len(trend, 2) // 2 different days
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour")
s.Require().NoError(err, "GetUserUsageTrendByUserID hourly")
s.Require().Len(trend, 3) // 3 different hours
}
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserModelStats")
s.Require().Len(stats, 2)
// Should be ordered by total_tokens DESC
s.Require().Equal("claude-3-opus", stats[0].Model)
s.Require().Equal(int64(300), stats[0].TotalTokens)
}
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats")
s.Require().Len(resp.History, 2, "expected 2 days of history")
s.Require().Equal(int64(2), resp.Summary.TotalRequests)
s.Require().Equal(int64(450), resp.Summary.TotalTokens)
s.Require().Len(resp.Models, 2)
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats empty")
s.Require().Len(resp.History, 0)
s.Require().Equal(int64(0), resp.Summary.TotalRequests)
}
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetUserUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters time range")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters combined")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db)
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &model.User{
Email: "create@test.com",
Username: "testuser",
Role: model.RoleUser,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, user)
s.Require().NoError(err, "Create")
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("create@test.com", got.Email)
}
func (s *UserRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, user.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(users, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(12.5, got.Balance)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(7.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(5.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Zero(got.Balance)
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(8, got.Concurrency)
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(3, got.Concurrency)
}
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
s.Require().True(exists)
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
for _, id := range got.AllowedGroups {
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
}
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "admin1@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "admin2@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "user@example.com",
Role: model.RoleUser,
Status: model.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().Error(err, "expected error when no admin exists")
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "disabled@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
Email: "active@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
got.Username = "Alice2"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
got2, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
suite.Run(t, new(UserSubscriptionRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
sub := &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
err := s.repo.Create(s.ctx, sub)
s.Require().NoError(err, "Create")
s.Require().NotZero(sub.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(sub.UserID, got.UserID)
s.Require().Equal(sub.GroupID, got.GroupID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
})
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.User, "expected User preload")
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().Equal(group.ID, got.Group.ID)
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated notes", got.Notes)
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.Delete(s.ctx, sub.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, sub.ID)
s.Require().Error(err, "expected error after delete")
}
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetByUserIDAndGroupID")
s.Require().Equal(sub.ID, got.ID)
s.Require().NotNil(got.Group, "expected Group preload")
}
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
s.Require().Error(err, "expected error for non-existent pair")
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().Error(err, "expected error for expired subscription")
}
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListByUserID")
s.Require().Len(subs, 2)
for _, sub := range subs {
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(subs, 2)
s.Require().Equal(int64(2), page.Total)
for _, sub := range subs {
s.Require().NotNil(sub.User, "expected User preload")
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
s.Require().NoError(err, "IncrementUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(1.25, got.DailyUsageUSD)
s.Require().Equal(1.25, got.WeeklyUsageUSD)
s.Require().Equal(1.25, got.MonthlyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(3.5, got.DailyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
s.Require().NoError(err, "ActivateWindows")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().NotNil(got.DailyWindowStart)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().True(got.DailyWindowStart.Equal(activateAt))
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
})
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetDailyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.DailyUsageUSD)
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
s.Require().True(got.DailyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
})
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetWeeklyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.WeeklyUsageUSD)
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
})
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetMonthlyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.MonthlyUsageUSD)
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
}
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
s.Require().NoError(err, "ExtendExpiry")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().True(got.ExpiresAt.Equal(newExpiry))
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
s.Require().NoError(err, "UpdateNotes")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal("VIP user", got.Notes)
}
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
expired, err := s.repo.ListExpired(s.ctx)
s.Require().NoError(err, "ListExpired")
s.Require().Len(expired, 1)
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
s.Require().True(exists)
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(2), count)
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
})
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountActiveByGroupID")
s.Require().Equal(int64(1), count, "only future expiry counts as active")
}
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "DeleteByGroupID")
s.Require().Equal(int64(2), affected)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID, "expected active subscription")
activateAt := time.Now().Add(-25 * time.Hour)
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
after, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID after reset")
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment