Commit bfcd9501 authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并 upstream/main 解决 PR #37 冲突

- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
parents 9780f0fd 12252c60
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository { ...@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists) return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var key model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var apiKey model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
} }
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param ...@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []model.ApiKey var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 { if userID > 0 {
db = db.Where("user_id = ?", userID) db = db.Where("user_id = ?", userID)
...@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw ...@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err return nil, err
} }
return keys, nil outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
...@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in ...@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) { ...@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey --- // --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() { func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &model.ApiKey{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-create-test", Key: "sk-create-test",
Name: "Test Key", Name: "Test Key",
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, key) err := s.repo.Create(s.ctx, key)
...@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { ...@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ApiKeyRepoSuite) TestGetByKey() { func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-getbykey", Key: "sk-getbykey",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
...@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { ...@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update --- // --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() { func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-update", Key: "sk-update",
Name: "Original", Name: "Original",
Status: model.StatusActive, Status: service.StatusActive,
}) }))
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update") s.Require().NoError(err, "Update")
...@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() { ...@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal("sk-update", got.Key, "Update should not change key") s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name) s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status) s.Require().Equal(service.StatusDisabled, got.Status)
} }
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-clear-group", Key: "sk-clear-group",
Name: "Group Key", Name: "Group Key",
GroupID: &group.ID, GroupID: &group.ID,
}) }))
key.GroupID = nil key.GroupID = nil
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
...@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { ...@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete --- // --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() { func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-delete", Key: "sk-delete",
Name: "Delete Me", Name: "Delete Me",
...@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() { ...@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID --- // --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() { func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID") s.Require().NoError(err, "ListByUserID")
...@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() { ...@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
} }
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{ mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)), Key: "sk-page-" + string(rune('a'+i)),
Name: "Key", Name: "Key",
...@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { ...@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
} }
func (s *ApiKeyRepoSuite) TestCountByUserID() { func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID) count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID") s.Require().NoError(err, "CountByUserID")
...@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() { ...@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID --- // --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() { func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID") s.Require().NoError(err, "ListByGroupID")
...@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() { ...@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
} }
func (s *ApiKeyRepoSuite) TestCountByGroupID() { func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID) count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID") s.Require().NoError(err, "CountByGroupID")
...@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() { ...@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey --- // --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() { func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey") s.Require().NoError(err, "ExistsByKey")
...@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() { ...@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys --- // --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() { func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchApiKeys")
...@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() { ...@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { ...@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { ...@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID --- // --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID") s.Require().NoError(err, "ClearGroupIDByGroupID")
...@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { ...@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-1", Key: "sk-test-1",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) }))
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey") s.Require().NoError(err, "GetByKey")
...@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(group.ID, got.Group.ID) s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
key.GroupID = nil key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update") s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
...@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key") s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name) s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status) s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID) s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
...@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID) s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID // ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-2", Key: "sk-test-2",
Name: "Group Key", Name: "Group Key",
......
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}
...@@ -18,6 +18,16 @@ const ( ...@@ -18,6 +18,16 @@ const (
billingCacheTTL = 5 * time.Minute billingCacheTTL = 5 * time.Minute
) )
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const ( const (
subFieldStatus = "status" subFieldStatus = "status"
subFieldExpiresAt = "expires_at" subFieldExpiresAt = "expires_at"
...@@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache { ...@@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache {
} }
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) { func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return 0, err return 0, err
...@@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 ...@@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
} }
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
} }
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
...@@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou ...@@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou
} }
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) { func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result() result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID ...@@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
return nil return nil
} }
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
fields := map[string]any{ fields := map[string]any{
subFieldStatus: data.Status, subFieldStatus: data.Status,
...@@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID ...@@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
} }
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
...@@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou ...@@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou
} }
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}
...@@ -11,6 +11,11 @@ import ( ...@@ -11,6 +11,11 @@ import (
const verifyCodeKeyPrefix = "verify_code:" const verifyCodeKeyPrefix = "verify_code:"
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
type emailCache struct { type emailCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache { ...@@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache {
} }
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) { func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se ...@@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se
} }
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error { func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
val, err := json.Marshal(data) val, err := json.Marshal(data)
if err != nil { if err != nil {
return err return err
...@@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data ...@@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data
} }
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error { func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}
...@@ -6,21 +6,25 @@ import ( ...@@ -6,21 +6,25 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper() t.Helper()
if u.PasswordHash == "" { if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash" u.PasswordHash = "test-password-hash"
} }
if u.Role == "" { if u.Role == "" {
u.Role = model.RoleUser u.Role = service.RoleUser
} }
if u.Status == "" { if u.Status == "" {
u.Status = model.StatusActive u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
} }
if u.CreatedAt.IsZero() { if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now() u.CreatedAt = time.Now()
...@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { ...@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return u return u
} }
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper() t.Helper()
if g.Platform == "" { if g.Platform == "" {
g.Platform = model.PlatformAnthropic g.Platform = service.PlatformAnthropic
} }
if g.Status == "" { if g.Status == "" {
g.Status = model.StatusActive g.Status = service.StatusActive
} }
if g.SubscriptionType == "" { if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard g.SubscriptionType = service.SubscriptionTypeStandard
} }
if g.CreatedAt.IsZero() { if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now() g.CreatedAt = time.Now()
...@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { ...@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return g return g
} }
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper() t.Helper()
if p.Protocol == "" { if p.Protocol == "" {
p.Protocol = "http" p.Protocol = "http"
...@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { ...@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p.Port = 8080 p.Port = 8080
} }
if p.Status == "" { if p.Status == "" {
p.Status = model.StatusActive p.Status = service.StatusActive
} }
if p.CreatedAt.IsZero() { if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now() p.CreatedAt = time.Now()
...@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { ...@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return p return p
} }
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account { func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper() t.Helper()
if a.Platform == "" { if a.Platform == "" {
a.Platform = model.PlatformAnthropic a.Platform = service.PlatformAnthropic
} }
if a.Type == "" { if a.Type == "" {
a.Type = model.AccountTypeOAuth a.Type = service.AccountTypeOAuth
} }
if a.Status == "" { if a.Status == "" {
a.Status = model.StatusActive a.Status = service.StatusActive
} }
if !a.Schedulable { if !a.Schedulable {
a.Schedulable = true a.Schedulable = true
} }
if a.Credentials == nil { if a.Credentials == nil {
a.Credentials = model.JSONB{} a.Credentials = datatypes.JSONMap{}
} }
if a.Extra == nil { if a.Extra == nil {
a.Extra = model.JSONB{} a.Extra = datatypes.JSONMap{}
} }
if a.CreatedAt.IsZero() { if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now() a.CreatedAt = time.Now()
...@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou ...@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return a return a
} }
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey { func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper() t.Helper()
if k.Status == "" { if k.Status == "" {
k.Status = model.StatusActive k.Status = service.StatusActive
} }
if k.CreatedAt.IsZero() { if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now() k.CreatedAt = time.Now()
...@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey ...@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return k return k
} }
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode { func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper() t.Helper()
if c.Status == "" { if c.Status == "" {
c.Status = model.StatusUnused c.Status = service.StatusUnused
} }
if c.Type == "" { if c.Type == "" {
c.Type = model.RedeemTypeBalance c.Type = service.RedeemTypeBalance
} }
if c.CreatedAt.IsZero() { if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now() c.CreatedAt = time.Now()
...@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model ...@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return c return c
} }
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription { func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper() t.Helper()
if s.Status == "" { if s.Status == "" {
s.Status = model.SubscriptionStatusActive s.Status = service.SubscriptionStatusActive
} }
now := time.Now() now := time.Now()
if s.StartsAt.IsZero() { if s.StartsAt.IsZero() {
...@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription ...@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper() t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{ require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: priority, Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group") }).Error, "create account_group")
} }
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository { ...@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db} return &groupRepository{db: db}
} }
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
err := r.db.WithContext(ctx).Create(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return translatePersistenceError(err, nil, service.ErrGroupExists) return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var group model.Group var m groupModel
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
} }
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
return r.db.WithContext(ctx).Save(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
} }
func (r *groupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []groupModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Group{}) db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters // Apply filters
if platform != "" { if platform != "" {
...@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count
} }
pages := int(total) / params.Limit() // 获取每个分组的账号数量
if int(total)%params.Limit() > 0 { for i := range outGroups {
pages++ count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, &pagination.PaginationResult{ return outGroups, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count }
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, nil return outGroups, nil
} }
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count }
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, nil return outGroups, nil
} }
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
...@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64 var affectedUserIDs []int64
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}). Table("user_subscriptions").
Where("group_id = ?", id). Where("group_id = ?", id).
Select("user_id"). Pluck("user_id", &affectedUserIDs).Error; err != nil {
Find(&subscriptions).Error; err != nil {
return nil, err return nil, err
} }
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
} }
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录 // 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
} }
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 3. 从 users.allowed_groups 数组中移除该分组 ID // 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}). if err := tx.Exec(
Where("? = ANY(allowed_groups)", id). "UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { id, id,
).Error; err != nil {
return err return err
} }
// 4. 删除 account_groups 中间表的数据 // 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 5. 删除分组本身(带锁,避免并发写) // 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil { if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err return err
} }
...@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil return affectedUserIDs, nil
} }
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) { ...@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() { func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{ group := &service.Group{
Name: "test-create", Name: "test-create",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, group) err := s.repo.Create(s.ctx, group)
...@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { ...@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
} }
func (s *GroupRepoSuite) TestUpdate() { func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"}) group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated" group.Name = "updated"
err := s.repo.Update(s.ctx, group) err := s.repo.Update(s.ctx, group)
...@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() { ...@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
} }
func (s *GroupRepoSuite) TestDelete() { func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID) err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() { ...@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() { func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() { ...@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
} }
func (s *GroupRepoSuite) TestListWithFilters_Platform() { func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform) s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
} }
func (s *GroupRepoSuite) TestListWithFilters_Status() { func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
} }
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
...@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { ...@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
} }
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{ g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1", Name: "g1",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
}) })
g2 := mustCreateGroup(s.T(), s.db, &model.Group{ g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2", Name: "g2",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
IsExclusive: true, IsExclusive: true,
}) })
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
...@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { ...@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform --- // --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() { func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx) groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() { ...@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
} }
func (s *GroupRepoSuite) TestListActiveByPlatform() { func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic) groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform") s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name) s.Require().Equal("g1", groups[0].Name)
...@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() { ...@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName --- // --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() { func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group") exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName") s.Require().NoError(err, "ExistsByName")
...@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() { ...@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount --- // --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() { func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
...@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { ...@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
} }
func (s *GroupRepoSuite) TestGetAccountCount_Empty() { func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { ...@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID --- // --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
...@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { ...@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
} }
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3) mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
......
...@@ -15,6 +15,11 @@ const ( ...@@ -15,6 +15,11 @@ const (
fingerprintTTL = 24 * time.Hour fingerprintTTL = 24 * time.Hour
) )
// fingerprintKey generates the Redis key for account fingerprint cache.
func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
type identityCache struct { type identityCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache { ...@@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
} }
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) { func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fingerprintKey(accountID)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s ...@@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s
} }
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error { func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fingerprintKey(accountID)
val, err := json.Marshal(fp) val, err := json.Marshal(fp)
if err != nil { if err != nil {
return err return err
......
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestFingerprintKey(t *testing.T) {
tests := []struct {
name string
accountID int64
expected string
}{
{
name: "normal_account_id",
accountID: 123,
expected: "fingerprint:123",
},
{
name: "zero_account_id",
accountID: 0,
expected: "fingerprint:0",
},
{
name: "negative_account_id",
accountID: -1,
expected: "fingerprint:-1",
},
{
name: "max_int64",
accountID: math.MaxInt64,
expected: "fingerprint:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := fingerprintKey(tc.accountID)
require.Equal(t, tc.expected, got)
})
}
}
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -94,7 +93,7 @@ func TestMain(m *testing.M) { ...@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open gorm db: %v", err) log.Printf("failed to open gorm db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := model.AutoMigrate(integrationDB); err != nil { if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err) log.Printf("failed to automigrate db: %v", err)
os.Exit(1) os.Exit(1)
} }
......
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository { ...@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var proxy model.Proxy var m proxyModel
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil) return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return proxyModelToService(&m), nil
} }
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
} }
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []proxyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{}) db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters // Apply filters
if protocol != "" { if protocol != "" {
...@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outProxies := make([]service.Proxy, 0, len(proxies))
if int(total)%params.Limit() > 0 { for i := range proxies {
pages++ outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
} }
return proxies, &pagination.PaginationResult{ return outProxies, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
return proxies, err if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return outProxies, nil
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error Count(&count).Error
if err != nil { if err != nil {
...@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, ...@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
Count(&count).Error Count(&count).Error
return count, err return count, err
...@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i ...@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
var results []result var results []result
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Model(&model.Account{}). Table("accounts").
Select("proxy_id, COUNT(*) as count"). Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL"). Where("proxy_id IS NOT NULL").
Group("proxy_id"). Group("proxy_id").
...@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i ...@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Order("created_at DESC"). Order("created_at DESC").
Find(&proxies).Error Find(&proxies).Error
if err != nil { if err != nil {
...@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod ...@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
} }
// Build result with account counts // Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i, proxy := range proxies { for i := range proxies {
result[i] = model.ProxyWithAccountCount{ proxy := proxyModelToService(&proxies[i])
Proxy: proxy, if proxy == nil {
AccountCount: counts[proxy.ID], continue
} }
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
})
} }
return result, nil return result, nil
} }
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
}
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) { ...@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() { func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{ proxy := &service.Proxy{
Name: "test-create", Name: "test-create",
Protocol: "http", Protocol: "http",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 8080, Port: 8080,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, proxy) err := s.repo.Create(s.ctx, proxy)
...@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() { ...@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ProxyRepoSuite) TestUpdate() { func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"}) proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy.Name = "updated" proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy) err := s.repo.Update(s.ctx, proxy)
...@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() { ...@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
} }
func (s *ProxyRepoSuite) TestDelete() { func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID) err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() { ...@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() { func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() { ...@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { ...@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Status() { func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(proxies, 1) s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status) s.Require().Equal(service.StatusDisabled, proxies[0].Status)
} }
func (s *ProxyRepoSuite) TestListWithFilters_Search() { func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err) s.Require().NoError(err)
...@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() { ...@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive --- // --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() { func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx) proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() { ...@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth --- // --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
...@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { ...@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
} }
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p-noauth", Name: "p-noauth",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
...@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { ...@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID --- // --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")
...@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { ...@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
} }
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { ...@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies --- // --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() { func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx) counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies") s.Require().NoError(err, "GetAccountCountsForProxies")
...@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() { ...@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour), CreatedAt: base.Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base, CreatedAt: base,
}) })
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive", Name: "p3-inactive",
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount") s.Require().NoError(err, "ListActiveWithAccountCount")
...@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { ...@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test --- // --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
...@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { ...@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
CreatedAt: time.Now().Add(-1 * time.Hour), CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour), UpdatedAt: time.Now().Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
...@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { ...@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
s.Require().NoError(err, "ExistsByHostPortAuth") s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist") s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID) count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")
......
...@@ -15,6 +15,16 @@ const ( ...@@ -15,6 +15,16 @@ const (
redeemRateLimitDuration = 24 * time.Hour redeemRateLimitDuration = 24 * time.Hour
) )
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
func redeemRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
}
// redeemLockKey generates the Redis key for redeem code locking.
func redeemLockKey(code string) string {
return redeemLockKeyPrefix + code
}
type redeemCache struct { type redeemCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache { ...@@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
} }
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) { func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) key := redeemRateLimitKey(userID)
return c.rdb.Get(ctx, key).Int() count, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
return count, err
} }
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error { func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) key := redeemRateLimitKey(userID)
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key) pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration) pipe.Expire(ctx, key, redeemRateLimitDuration)
...@@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in ...@@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in
} }
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) { func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code key := redeemLockKey(code)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()
} }
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error { func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code key := redeemLockKey(code)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
...@@ -3,12 +3,10 @@ ...@@ -3,12 +3,10 @@
package repository package repository
import ( import (
"errors"
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
...@@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() { ...@@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() {
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() { func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999) missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID) count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key") require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil)) require.Equal(s.T(), 0, count, "expected zero count for missing key")
} }
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() { func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
......
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedeemRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "redeem:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "redeem:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "redeem:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "redeem:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestRedeemLockKey(t *testing.T) {
tests := []struct {
name string
code string
expected string
}{
{
name: "normal_code",
code: "ABC123",
expected: "redeem:lock:ABC123",
},
{
name: "empty_code",
code: "",
expected: "redeem:lock:",
},
{
name: "code_with_special_chars",
code: "CODE-2024:test",
expected: "redeem:lock:CODE-2024:test",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemLockKey(tc.code)
require.Equal(t, tc.expected, got)
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment