Commit 22f07a7b authored by shaw's avatar shaw
Browse files

Merge PR #36: refactor: 调整项目结构为单向依赖

parents ecb2c535 e5a77853
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -35,19 +36,13 @@ type UpdateProfileRequest struct { ...@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile // GetProfile handles getting user profile
// GET /api/v1/users/me // GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) { func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
userData, err := h.userService.GetByID(c.Request.Context(), user.ID) userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) { ...@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注 // 清空notes字段,普通用户不应看到备注
userData.Notes = "" userData.Notes = ""
response.Success(c, userData) response.Success(c, dto.UserFromService(userData))
} }
// ChangePassword handles changing user password // ChangePassword handles changing user password
// POST /api/v1/users/me/password // POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) { func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { ...@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword, CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword, NewPassword: req.NewPassword,
} }
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { ...@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile // UpdateProfile handles updating user profile
// PUT /api/v1/users/me // PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) { func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username: req.Username, Username: req.Username,
Wechat: req.Wechat, Wechat: req.Wechat,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注 // 清空notes字段,普通用户不应看到备注
updatedUser.Notes = "" updatedUser.Notes = ""
response.Success(c, updatedUser) response.Success(c, dto.UserFromService(updatedUser))
} }
...@@ -2,8 +2,8 @@ package infrastructure ...@@ -2,8 +2,8 @@ package infrastructure
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) { ...@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步) // 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil { if err := repository.AutoMigrate(db); err != nil {
return nil, err return nil, err
} }
......
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}
...@@ -5,10 +5,10 @@ import ( ...@@ -5,10 +5,10 @@ import (
"errors" "errors"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
...@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository { ...@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db} return &accountRepository{db: db}
} }
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error { func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
return r.db.WithContext(ctx).Create(account).Error m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
} }
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var account model.Account var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
// 填充 GroupIDs 和 Groups 虚拟字段 return accountModelToService(&m), nil
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
} }
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" { if crsAccountID == "" {
return nil, nil return nil, nil
} }
var account model.Account var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
return &account, nil return accountModelToService(&m), nil
} }
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error { func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
return r.db.WithContext(ctx).Save(account).Error m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
} }
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系 if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
} }
// 再删除账号 return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { var accounts []accountModel
var accounts []model.Account
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Account{}) db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" { if platform != "" {
db = db.Where("platform = ?", platform) db = db.Where("platform = ?", platform)
} }
...@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err return nil, nil, err
} }
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups) outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts { for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups)) outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
} }
pages := int(total) / params.Limit() return outAccounts, paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []model.Account var accounts []accountModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive). Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []model.Account var accounts []accountModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
} }
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusError, "status": service.StatusError,
"error_message": errorMsg, "error_message": errorMsg,
}).Error }).Error
} }
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{ ag := &accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: priority, Priority: priority,
...@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i ...@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error Delete(&accountGroupModel{}).Error
} }
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID). Where("account_groups.account_id = ?", accountID).
Find(&groups).Error Find(&groups).Error
return groups, err if err != nil {
} return nil, err
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { outGroups := make([]service.Group, 0, len(groups))
var accounts []model.Account for i := range groups {
err := r.db.WithContext(ctx). outGroups = append(outGroups, *groupModelToService(&groups[i]))
Where("platform = ? AND status = ?", platform, model.StatusActive). }
Preload("Proxy"). return outGroups, nil
Order("priority ASC").
Find(&accounts).Error
return accounts, err
} }
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定 if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
} }
// 添加新绑定 if len(groupIDs) == 0 {
if len(groupIDs) > 0 { return nil
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
} }
return nil accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
} }
// ListSchedulable 获取所有可调度的账号 func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByGroupID 按组获取可调度的账号 func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID). Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByPlatform 按平台获取可调度的账号 func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("platform = ?", platform). Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID). Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform). Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// SetRateLimited 标记账号为限流状态(429)
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": now, "rate_limited_at": now,
"rate_limit_reset_at": resetAt, "rate_limit_reset_at": resetAt,
}).Error }).Error
} }
// SetOverloaded 标记账号为过载状态(529)
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error Update("overload_until", until).Error
} }
// ClearRateLimit 清除账号的限流状态
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
"rate_limit_reset_at": nil, "rate_limit_reset_at": nil,
...@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error ...@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
}).Error }).Error
} }
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
...@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s ...@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil { if end != nil {
updates["session_window_end"] = end updates["session_window_end"] = end
} }
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
} }
// SetSchedulable 设置账号的调度开关
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
} }
// Get current account to preserve existing Extra data var account accountModel
var account model.Account
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil { if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err return err
} }
// Initialize Extra if nil
if account.Extra == nil { if account.Extra == nil {
account.Extra = make(model.JSONB) account.Extra = datatypes.JSONMap{}
} }
// Merge updates into existing Extra
for k, v := range updates { for k, v := range updates {
account.Extra[k] = v account.Extra[k] = v
} }
// Save updated Extra return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error Update("extra", account.Extra).Error
} }
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
...@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
updateMap["status"] = *updates.Status updateMap["status"] = *updates.Status
} }
if len(updates.Credentials) > 0 { if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials) updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
} }
if len(updates.Extra) > 0 { if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra) updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
} }
if len(updateMap) == 0 { if len(updateMap) == 0 {
...@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
} }
result := r.db.WithContext(ctx). result := r.db.WithContext(ctx).
Model(&model.Account{}). Model(&accountModel{}).
Where("id IN ?", ids). Where("id IN ?", ids).
Clauses(clause.Returning{}). Clauses(clause.Returning{}).
Updates(updateMap) Updates(updateMap)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}
...@@ -7,10 +7,10 @@ import ( ...@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) { ...@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() { func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{ account := &service.Account{
Name: "test-create", Name: "test-create",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Type: model.AccountTypeOAuth, Type: service.AccountTypeOAuth,
Status: model.StatusActive, Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
} }
err := s.repo.Create(s.ctx, account) err := s.repo.Create(s.ctx, account)
...@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() { ...@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
} }
func (s *AccountRepoSuite) TestUpdate() { func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"}) account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated" account.Name = "updated"
err := s.repo.Update(s.ctx, account) err := s.repo.Update(s.ctx, account)
...@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() { ...@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
} }
func (s *AccountRepoSuite) TestDelete() { func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() { ...@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
} }
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings") s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64 var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count) s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed") s.Require().Zero(count, "expected bindings to be removed")
} }
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() { func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
status string status string
search string search string
wantCount int wantCount int
validate func(accounts []model.Account) validate func(accounts []service.Account)
}{ }{
{ {
name: "filter_by_platform", name: "filter_by_platform",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic}) mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI}) mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
}, },
platform: model.PlatformOpenAI, platform: service.PlatformOpenAI,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform) s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
}, },
}, },
{ {
name: "filter_by_type", name: "filter_by_type",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth}) mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey}) mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
}, },
accType: model.AccountTypeApiKey, accType: service.AccountTypeApiKey,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type) s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
}, },
}, },
{ {
name: "filter_by_status", name: "filter_by_status",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive}) mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled}) mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
}, },
status: model.StatusDisabled, status: service.StatusDisabled,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status) s.Require().Equal(service.StatusDisabled, accounts[0].Status)
}, },
}, },
{ {
name: "filter_by_search", name: "filter_by_search",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"}) mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"}) mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
}, },
search: "alpha", search: "alpha",
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha") s.Require().Contains(accounts[0].Name, "alpha")
}, },
}, },
...@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform --- // --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() { func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive}) acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive}) acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
...@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() { ...@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
} }
func (s *AccountRepoSuite) TestListActive() { func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx) accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() { ...@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
} }
func (s *AccountRepoSuite) TestListByPlatform() { func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic) accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform") s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
} }
// --- Preload and VirtualFields --- // --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{ account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1", Name: "acc1",
ProxyID: &proxy.ID, ProxyID: &proxy.ID,
}) })
...@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { ...@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups --- // --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup") s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID) groups, err := s.repo.GetGroups(s.ctx, account.ID)
...@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { ...@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
} }
func (s *AccountRepoSuite) TestBindGroups_EmptyList() { func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty") s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
...@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() { ...@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() { func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx) sched, err := s.repo.ListSchedulable(s.ctx)
...@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() { ...@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() { func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true}) rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited") s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
...@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu ...@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
} }
func (s *AccountRepoSuite) TestListSchedulableByPlatform() { func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
} }
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID) s.Require().Equal(a1.ID, accounts[0].ID)
} }
func (s *AccountRepoSuite) TestSetSchedulable() { func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
...@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() { ...@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit --- // --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() { func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
...@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() { ...@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
} }
func (s *AccountRepoSuite) TestSetRateLimited() { func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC) resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
...@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() { ...@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
} }
func (s *AccountRepoSuite) TestClearRateLimit() { func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour) until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
...@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() { ...@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed --- // --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() { func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt) s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID)) s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
...@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() { ...@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError --- // --- SetError ---
func (s *AccountRepoSuite) TestSetError() { func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status) s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage) s.Require().Equal("something went wrong", got.ErrorMessage)
} }
// --- UpdateSessionWindow --- // --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() { func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC) end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
...@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() { ...@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra --- // --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{ account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra", Name: "acc-extra",
Extra: model.JSONB{"a": "1"}, Extra: datatypes.JSONMap{"a": "1"},
}) })
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra") s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
...@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { ...@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
} }
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() { func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
} }
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
...@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { ...@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() { func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345" crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{ mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs", Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID}, Extra: datatypes.JSONMap{"crs_account_id": crsID},
}) })
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID) got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
...@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() { ...@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate --- // --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() { func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99 newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{ affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
...@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() { ...@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{ a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred", Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"}, Credentials: datatypes.JSONMap{"existing": "value"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"}, Credentials: datatypes.JSONMap{"new_key": "new_value"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
...@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { ...@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() { func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{ a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra", Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"}, Extra: datatypes.JSONMap{"existing": "val"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"}, Extra: datatypes.JSONMap{"new_key": "new_val"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
...@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() { ...@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(affected) s.Require().Zero(affected)
} }
func idsOfAccounts(accounts []model.Account) []int64 { func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts)) out := make([]int64, 0, len(accounts))
for i := range accounts { for i := range accounts {
out = append(out, accounts[i].ID) out = append(out, accounts[i].ID)
......
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository { ...@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists) return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var key model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var apiKey model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
} }
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param ...@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
...@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []model.ApiKey var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 { if userID > 0 {
db = db.Where("user_id = ?", userID) db = db.Where("user_id = ?", userID)
...@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw ...@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err return nil, err
} }
return keys, nil outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
...@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in ...@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) { ...@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey --- // --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() { func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &model.ApiKey{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-create-test", Key: "sk-create-test",
Name: "Test Key", Name: "Test Key",
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, key) err := s.repo.Create(s.ctx, key)
...@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { ...@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ApiKeyRepoSuite) TestGetByKey() { func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-getbykey", Key: "sk-getbykey",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
...@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { ...@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update --- // --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() { func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-update", Key: "sk-update",
Name: "Original", Name: "Original",
Status: model.StatusActive, Status: service.StatusActive,
}) }))
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update") s.Require().NoError(err, "Update")
...@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() { ...@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal("sk-update", got.Key, "Update should not change key") s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name) s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status) s.Require().Equal(service.StatusDisabled, got.Status)
} }
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-clear-group", Key: "sk-clear-group",
Name: "Group Key", Name: "Group Key",
GroupID: &group.ID, GroupID: &group.ID,
}) }))
key.GroupID = nil key.GroupID = nil
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
...@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { ...@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete --- // --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() { func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-delete", Key: "sk-delete",
Name: "Delete Me", Name: "Delete Me",
...@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() { ...@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID --- // --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() { func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID") s.Require().NoError(err, "ListByUserID")
...@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() { ...@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
} }
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{ mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)), Key: "sk-page-" + string(rune('a'+i)),
Name: "Key", Name: "Key",
...@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { ...@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
} }
func (s *ApiKeyRepoSuite) TestCountByUserID() { func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID) count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID") s.Require().NoError(err, "CountByUserID")
...@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() { ...@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID --- // --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() { func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID") s.Require().NoError(err, "ListByGroupID")
...@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() { ...@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
} }
func (s *ApiKeyRepoSuite) TestCountByGroupID() { func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID) count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID") s.Require().NoError(err, "CountByGroupID")
...@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() { ...@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey --- // --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() { func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey") s.Require().NoError(err, "ExistsByKey")
...@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() { ...@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys --- // --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() { func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchApiKeys")
...@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() { ...@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { ...@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { ...@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID --- // --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID") s.Require().NoError(err, "ClearGroupIDByGroupID")
...@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { ...@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-1", Key: "sk-test-1",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) }))
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey") s.Require().NoError(err, "GetByKey")
...@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(group.ID, got.Group.ID) s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
key.GroupID = nil key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update") s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
...@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key") s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name) s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status) s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID) s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
...@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID) s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID // ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-2", Key: "sk-test-2",
Name: "Group Key", Name: "Group Key",
......
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}
...@@ -6,21 +6,25 @@ import ( ...@@ -6,21 +6,25 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper() t.Helper()
if u.PasswordHash == "" { if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash" u.PasswordHash = "test-password-hash"
} }
if u.Role == "" { if u.Role == "" {
u.Role = model.RoleUser u.Role = service.RoleUser
} }
if u.Status == "" { if u.Status == "" {
u.Status = model.StatusActive u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
} }
if u.CreatedAt.IsZero() { if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now() u.CreatedAt = time.Now()
...@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { ...@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return u return u
} }
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper() t.Helper()
if g.Platform == "" { if g.Platform == "" {
g.Platform = model.PlatformAnthropic g.Platform = service.PlatformAnthropic
} }
if g.Status == "" { if g.Status == "" {
g.Status = model.StatusActive g.Status = service.StatusActive
} }
if g.SubscriptionType == "" { if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard g.SubscriptionType = service.SubscriptionTypeStandard
} }
if g.CreatedAt.IsZero() { if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now() g.CreatedAt = time.Now()
...@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { ...@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return g return g
} }
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper() t.Helper()
if p.Protocol == "" { if p.Protocol == "" {
p.Protocol = "http" p.Protocol = "http"
...@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { ...@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p.Port = 8080 p.Port = 8080
} }
if p.Status == "" { if p.Status == "" {
p.Status = model.StatusActive p.Status = service.StatusActive
} }
if p.CreatedAt.IsZero() { if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now() p.CreatedAt = time.Now()
...@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { ...@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return p return p
} }
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account { func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper() t.Helper()
if a.Platform == "" { if a.Platform == "" {
a.Platform = model.PlatformAnthropic a.Platform = service.PlatformAnthropic
} }
if a.Type == "" { if a.Type == "" {
a.Type = model.AccountTypeOAuth a.Type = service.AccountTypeOAuth
} }
if a.Status == "" { if a.Status == "" {
a.Status = model.StatusActive a.Status = service.StatusActive
} }
if !a.Schedulable { if !a.Schedulable {
a.Schedulable = true a.Schedulable = true
} }
if a.Credentials == nil { if a.Credentials == nil {
a.Credentials = model.JSONB{} a.Credentials = datatypes.JSONMap{}
} }
if a.Extra == nil { if a.Extra == nil {
a.Extra = model.JSONB{} a.Extra = datatypes.JSONMap{}
} }
if a.CreatedAt.IsZero() { if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now() a.CreatedAt = time.Now()
...@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou ...@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return a return a
} }
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey { func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper() t.Helper()
if k.Status == "" { if k.Status == "" {
k.Status = model.StatusActive k.Status = service.StatusActive
} }
if k.CreatedAt.IsZero() { if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now() k.CreatedAt = time.Now()
...@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey ...@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return k return k
} }
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode { func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper() t.Helper()
if c.Status == "" { if c.Status == "" {
c.Status = model.StatusUnused c.Status = service.StatusUnused
} }
if c.Type == "" { if c.Type == "" {
c.Type = model.RedeemTypeBalance c.Type = service.RedeemTypeBalance
} }
if c.CreatedAt.IsZero() { if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now() c.CreatedAt = time.Now()
...@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model ...@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return c return c
} }
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription { func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper() t.Helper()
if s.Status == "" { if s.Status == "" {
s.Status = model.SubscriptionStatusActive s.Status = service.SubscriptionStatusActive
} }
now := time.Now() now := time.Now()
if s.StartsAt.IsZero() { if s.StartsAt.IsZero() {
...@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription ...@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper() t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{ require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: priority, Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group") }).Error, "create account_group")
} }
...@@ -2,10 +2,10 @@ package repository ...@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository { ...@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db} return &groupRepository{db: db}
} }
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
err := r.db.WithContext(ctx).Create(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return translatePersistenceError(err, nil, service.ErrGroupExists) return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var group model.Group var m groupModel
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
} }
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
return r.db.WithContext(ctx).Save(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
} }
func (r *groupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []groupModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Group{}) db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters // Apply filters
if platform != "" { if platform != "" {
...@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count
} }
pages := int(total) / params.Limit() // 获取每个分组的账号数量
if int(total)%params.Limit() > 0 { for i := range outGroups {
pages++ count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, &pagination.PaginationResult{ return outGroups, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count }
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, nil return outGroups, nil
} }
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count }
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, nil return outGroups, nil
} }
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
...@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64 var affectedUserIDs []int64
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}). Table("user_subscriptions").
Where("group_id = ?", id). Where("group_id = ?", id).
Select("user_id"). Pluck("user_id", &affectedUserIDs).Error; err != nil {
Find(&subscriptions).Error; err != nil {
return nil, err return nil, err
} }
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
} }
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录 // 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
} }
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 3. 从 users.allowed_groups 数组中移除该分组 ID // 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}). if err := tx.Exec(
Where("? = ANY(allowed_groups)", id). "UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { id, id,
).Error; err != nil {
return err return err
} }
// 4. 删除 account_groups 中间表的数据 // 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 5. 删除分组本身(带锁,避免并发写) // 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil { if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err return err
} }
...@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil return affectedUserIDs, nil
} }
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) { ...@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() { func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{ group := &service.Group{
Name: "test-create", Name: "test-create",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, group) err := s.repo.Create(s.ctx, group)
...@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { ...@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
} }
func (s *GroupRepoSuite) TestUpdate() { func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"}) group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated" group.Name = "updated"
err := s.repo.Update(s.ctx, group) err := s.repo.Update(s.ctx, group)
...@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() { ...@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
} }
func (s *GroupRepoSuite) TestDelete() { func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID) err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() { ...@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() { func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() { ...@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
} }
func (s *GroupRepoSuite) TestListWithFilters_Platform() { func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform) s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
} }
func (s *GroupRepoSuite) TestListWithFilters_Status() { func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
} }
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
...@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { ...@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
} }
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{ g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1", Name: "g1",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
}) })
g2 := mustCreateGroup(s.T(), s.db, &model.Group{ g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2", Name: "g2",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
IsExclusive: true, IsExclusive: true,
}) })
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
...@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { ...@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform --- // --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() { func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx) groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() { ...@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
} }
func (s *GroupRepoSuite) TestListActiveByPlatform() { func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic) groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform") s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name) s.Require().Equal("g1", groups[0].Name)
...@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() { ...@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName --- // --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() { func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group") exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName") s.Require().NoError(err, "ExistsByName")
...@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() { ...@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount --- // --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() { func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
...@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { ...@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
} }
func (s *GroupRepoSuite) TestGetAccountCount_Empty() { func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { ...@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID --- // --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
...@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { ...@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
} }
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3) mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
......
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -94,7 +93,7 @@ func TestMain(m *testing.M) { ...@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open gorm db: %v", err) log.Printf("failed to open gorm db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := model.AutoMigrate(integrationDB); err != nil { if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err) log.Printf("failed to automigrate db: %v", err)
os.Exit(1) os.Exit(1)
} }
......
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment