Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
e5a77853
Commit
e5a77853
authored
Dec 26, 2025
by
Forest
Browse files
refactor: 调整项目结构为单向依赖
parent
b3463769
Changes
95
Expand all
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/user_handler.go
View file @
e5a77853
package
handler
import
(
"github.com/Wei-Shaw/sub2api/internal/
model
"
"github.com/Wei-Shaw/sub2api/internal/
handler/dto
"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile
// GET /api/v1/users/me
func
(
h
*
UserHandler
)
GetProfile
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
}
userData
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
u
ser
.
ID
)
userData
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
subject
.
U
serID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注
userData
.
Notes
=
""
response
.
Success
(
c
,
userData
)
response
.
Success
(
c
,
dto
.
UserFromService
(
userData
)
)
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func
(
h
*
UserHandler
)
ChangePassword
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
}
...
...
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword
:
req
.
OldPassword
,
NewPassword
:
req
.
NewPassword
,
}
err
:=
h
.
userService
.
ChangePassword
(
c
.
Request
.
Context
(),
u
ser
.
ID
,
svcReq
)
err
:=
h
.
userService
.
ChangePassword
(
c
.
Request
.
Context
(),
subject
.
U
serID
,
svcReq
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func
(
h
*
UserHandler
)
UpdateProfile
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
}
...
...
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username
:
req
.
Username
,
Wechat
:
req
.
Wechat
,
}
updatedUser
,
err
:=
h
.
userService
.
UpdateProfile
(
c
.
Request
.
Context
(),
u
ser
.
ID
,
svcReq
)
updatedUser
,
err
:=
h
.
userService
.
UpdateProfile
(
c
.
Request
.
Context
(),
subject
.
U
serID
,
svcReq
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注
updatedUser
.
Notes
=
""
response
.
Success
(
c
,
updatedUser
)
response
.
Success
(
c
,
dto
.
UserFromService
(
updatedUser
)
)
}
backend/internal/infrastructure/database.go
View file @
e5a77853
...
...
@@ -2,8 +2,8 @@ package infrastructure
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
...
...
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if
err
:=
model
.
AutoMigrate
(
db
);
err
!=
nil
{
if
err
:=
repository
.
AutoMigrate
(
db
);
err
!=
nil
{
return
nil
,
err
}
...
...
backend/internal/model/account_group.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"time"
)
type
AccountGroup
struct
{
AccountID
int64
`gorm:"primaryKey" json:"account_id"`
GroupID
int64
`gorm:"primaryKey" json:"group_id"`
Priority
int
`gorm:"default:50;not null" json:"priority"`
// 分组内优先级
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 关联
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
AccountGroup
)
TableName
()
string
{
return
"account_groups"
}
backend/internal/model/api_key.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"time"
"gorm.io/gorm"
)
type
ApiKey
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
Key
string
`gorm:"uniqueIndex;size:128;not null" json:"key"`
// sk-xxx
Name
string
`gorm:"size:100;not null" json:"name"`
GroupID
*
int64
`gorm:"index" json:"group_id"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
ApiKey
)
TableName
()
string
{
return
"api_keys"
}
// IsActive 检查是否激活
func
(
k
*
ApiKey
)
IsActive
()
bool
{
return
k
.
Status
==
"active"
}
backend/internal/model/group.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const
(
SubscriptionTypeStandard
=
"standard"
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"subscription"
// 订阅模式(按限额控制)
)
type
Group
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"uniqueIndex;size:100;not null" json:"name"`
Description
string
`gorm:"type:text" json:"description"`
Platform
string
`gorm:"size:50;default:anthropic;not null" json:"platform"`
// anthropic/openai/gemini
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive
bool
`gorm:"default:false;not null" json:"is_exclusive"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
// 订阅功能字段
SubscriptionType
string
`gorm:"size:20;default:standard;not null" json:"subscription_type"`
// standard/subscription
DailyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
AccountGroups
[]
AccountGroup
`gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount
int64
`gorm:"-" json:"account_count,omitempty"`
}
func
(
Group
)
TableName
()
string
{
return
"groups"
}
// IsActive 检查是否激活
func
(
g
*
Group
)
IsActive
()
bool
{
return
g
.
Status
==
"active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func
(
g
*
Group
)
IsSubscriptionType
()
bool
{
return
g
.
SubscriptionType
==
SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func
(
g
*
Group
)
IsFreeSubscription
()
bool
{
return
g
.
IsSubscriptionType
()
&&
g
.
RateMultiplier
==
0
}
// HasDailyLimit 检查是否有日限额
func
(
g
*
Group
)
HasDailyLimit
()
bool
{
return
g
.
DailyLimitUSD
!=
nil
&&
*
g
.
DailyLimitUSD
>
0
}
// HasWeeklyLimit 检查是否有周限额
func
(
g
*
Group
)
HasWeeklyLimit
()
bool
{
return
g
.
WeeklyLimitUSD
!=
nil
&&
*
g
.
WeeklyLimitUSD
>
0
}
// HasMonthlyLimit 检查是否有月限额
func
(
g
*
Group
)
HasMonthlyLimit
()
bool
{
return
g
.
MonthlyLimitUSD
!=
nil
&&
*
g
.
MonthlyLimitUSD
>
0
}
backend/internal/model/model.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func
AutoMigrate
(
db
*
gorm
.
DB
)
error
{
return
db
.
AutoMigrate
(
&
User
{},
&
ApiKey
{},
&
Group
{},
&
Account
{},
&
AccountGroup
{},
&
Proxy
{},
&
RedeemCode
{},
&
UsageLog
{},
&
Setting
{},
&
UserSubscription
{},
)
}
// 状态常量
const
(
StatusActive
=
"active"
StatusDisabled
=
"disabled"
StatusError
=
"error"
StatusUnused
=
"unused"
StatusUsed
=
"used"
StatusExpired
=
"expired"
)
// 角色常量
const
(
RoleAdmin
=
"admin"
RoleUser
=
"user"
)
// 平台常量
const
(
PlatformAnthropic
=
"anthropic"
PlatformOpenAI
=
"openai"
PlatformGemini
=
"gemini"
)
// 账号类型常量
const
(
AccountTypeOAuth
=
"oauth"
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"setup-token"
// Setup Token类型账号(inference only scope)
AccountTypeApiKey
=
"apikey"
// API Key类型账号
)
// 卡密类型常量
const
(
RedeemTypeBalance
=
"balance"
RedeemTypeConcurrency
=
"concurrency"
RedeemTypeSubscription
=
"subscription"
)
// 管理员调整类型常量
const
(
AdjustmentTypeAdminBalance
=
"admin_balance"
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"admin_concurrency"
// 管理员调整并发数
)
backend/internal/model/proxy.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"fmt"
"time"
"gorm.io/gorm"
)
type
Proxy
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"size:100;not null" json:"name"`
Protocol
string
`gorm:"size:20;not null" json:"protocol"`
// http/https/socks5
Host
string
`gorm:"size:255;not null" json:"host"`
Port
int
`gorm:"not null" json:"port"`
Username
string
`gorm:"size:100" json:"username"`
Password
string
`gorm:"size:100" json:"-"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
}
func
(
Proxy
)
TableName
()
string
{
return
"proxies"
}
// IsActive 检查是否激活
func
(
p
*
Proxy
)
IsActive
()
bool
{
return
p
.
Status
==
"active"
}
// URL 返回代理URL
func
(
p
*
Proxy
)
URL
()
string
{
if
p
.
Username
!=
""
&&
p
.
Password
!=
""
{
return
fmt
.
Sprintf
(
"%s://%s:%s@%s:%d"
,
p
.
Protocol
,
p
.
Username
,
p
.
Password
,
p
.
Host
,
p
.
Port
)
}
return
fmt
.
Sprintf
(
"%s://%s:%d"
,
p
.
Protocol
,
p
.
Host
,
p
.
Port
)
}
// ProxyWithAccountCount extends Proxy with account count information
type
ProxyWithAccountCount
struct
{
Proxy
AccountCount
int64
`json:"account_count"`
}
backend/internal/model/redeem_code.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"crypto/rand"
"encoding/hex"
"time"
)
type
RedeemCode
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Code
string
`gorm:"uniqueIndex;size:32;not null" json:"code"`
Type
string
`gorm:"size:20;default:balance;not null" json:"type"`
// balance/concurrency/subscription
Value
float64
`gorm:"type:decimal(20,8);not null" json:"value"`
// 面值(USD)或并发数或有效天数
Status
string
`gorm:"size:20;default:unused;not null" json:"status"`
// unused/used
UsedBy
*
int64
`gorm:"index" json:"used_by"`
UsedAt
*
time
.
Time
`json:"used_at"`
Notes
string
`gorm:"type:text" json:"notes"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID
*
int64
`gorm:"index" json:"group_id"`
// 订阅分组ID (仅subscription类型使用)
ValidityDays
int
`gorm:"default:30" json:"validity_days"`
// 订阅有效天数 (仅subscription类型使用)
// 关联
User
*
User
`gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
RedeemCode
)
TableName
()
string
{
return
"redeem_codes"
}
// IsUsed 检查是否已使用
func
(
r
*
RedeemCode
)
IsUsed
()
bool
{
return
r
.
Status
==
"used"
}
// CanUse 检查是否可以使用
func
(
r
*
RedeemCode
)
CanUse
()
bool
{
return
r
.
Status
==
"unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func
GenerateRedeemCode
()
(
string
,
error
)
{
b
:=
make
([]
byte
,
16
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
backend/internal/model/usage_log.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"time"
)
// 消费类型常量
const
(
BillingTypeBalance
int8
=
0
// 钱包余额
BillingTypeSubscription
int8
=
1
// 订阅套餐
)
type
UsageLog
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
ApiKeyID
int64
`gorm:"index;not null" json:"api_key_id"`
AccountID
int64
`gorm:"index;not null" json:"account_id"`
RequestID
string
`gorm:"size:64" json:"request_id"`
Model
string
`gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID
*
int64
`gorm:"index" json:"group_id"`
SubscriptionID
*
int64
`gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens
int
`gorm:"default:0;not null" json:"input_tokens"`
OutputTokens
int
`gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens
int
`gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens
int
`gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens
int
`gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens
int
`gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"`
// 原始总费用
ActualCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"`
// 实际扣除费用
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"`
// 计费倍率
// 元数据
BillingType
int8
`gorm:"type:smallint;default:0;not null" json:"billing_type"`
// 0=余额 1=订阅
Stream
bool
`gorm:"default:false;not null" json:"stream"`
DurationMs
*
int
`json:"duration_ms"`
FirstTokenMs
*
int
`json:"first_token_ms"`
// 首字时间(流式请求)
CreatedAt
time
.
Time
`gorm:"index;not null" json:"created_at"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey
*
ApiKey
`gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription
*
UserSubscription
`gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func
(
UsageLog
)
TableName
()
string
{
return
"usage_logs"
}
// TotalTokens 总token数
func
(
u
*
UsageLog
)
TotalTokens
()
int
{
return
u
.
InputTokens
+
u
.
OutputTokens
+
u
.
CacheCreationTokens
+
u
.
CacheReadTokens
}
backend/internal/model/user.go
deleted
100644 → 0
View file @
b3463769
package
model
import
(
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type
User
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Email
string
`gorm:"uniqueIndex;size:255;not null" json:"email"`
Username
string
`gorm:"size:100;default:''" json:"username"`
Wechat
string
`gorm:"size:100;default:''" json:"wechat"`
Notes
string
`gorm:"type:text;default:''" json:"notes"`
PasswordHash
string
`gorm:"size:255;not null" json:"-"`
Role
string
`gorm:"size:20;default:user;not null" json:"role"`
// admin/user
Balance
float64
`gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency
int
`gorm:"default:5;not null" json:"concurrency"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
AllowedGroups
pq
.
Int64Array
`gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
ApiKeys
[]
ApiKey
`gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions
[]
UserSubscription
`gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func
(
User
)
TableName
()
string
{
return
"users"
}
// IsAdmin 检查是否管理员
func
(
u
*
User
)
IsAdmin
()
bool
{
return
u
.
Role
==
"admin"
}
// IsActive 检查是否激活
func
(
u
*
User
)
IsActive
()
bool
{
return
u
.
Status
==
"active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func
(
u
*
User
)
CanBindGroup
(
groupID
int64
,
isExclusive
bool
)
bool
{
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if
len
(
u
.
AllowedGroups
)
>
0
{
for
_
,
id
:=
range
u
.
AllowedGroups
{
if
id
==
groupID
{
return
true
}
}
return
false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return
!
isExclusive
}
// SetPassword 设置密码(哈希存储)
func
(
u
*
User
)
SetPassword
(
password
string
)
error
{
hash
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
err
}
u
.
PasswordHash
=
string
(
hash
)
return
nil
}
// CheckPassword 验证密码
func
(
u
*
User
)
CheckPassword
(
password
string
)
bool
{
err
:=
bcrypt
.
CompareHashAndPassword
([]
byte
(
u
.
PasswordHash
),
[]
byte
(
password
))
return
err
==
nil
}
backend/internal/repository/account_repo.go
View file @
e5a77853
This diff is collapsed.
Click to expand it.
backend/internal/repository/account_repo_integration_test.go
View file @
e5a77853
This diff is collapsed.
Click to expand it.
backend/internal/repository/api_key_repo.go
View file @
e5a77853
...
...
@@ -2,10 +2,10 @@ package repository
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
...
...
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return
&
apiKeyRepository
{
db
:
db
}
}
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
key
)
.
Error
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
m
:=
apiKeyModelFromService
(
key
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
m
)
.
Error
if
err
==
nil
{
applyApiKeyModelToService
(
key
,
m
)
}
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrApiKeyExists
)
}
func
(
r
*
apiKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
var
key
model
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
key
,
id
)
.
Error
func
(
r
*
apiKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ApiKey
,
error
)
{
var
m
apiKeyModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
m
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
return
&
key
,
nil
return
apiKeyModelToService
(
&
m
)
,
nil
}
func
(
r
*
apiKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
var
apiKey
m
odel
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
apiKey
)
.
Error
func
(
r
*
apiKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
ApiKey
,
error
)
{
var
m
apiKey
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
m
)
.
Error
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
return
&
apiKey
,
nil
return
apiKey
ModelToService
(
&
m
)
,
nil
}
func
(
r
*
apiKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
key
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
key
)
.
Error
func
(
r
*
apiKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
m
:=
apiKeyModelFromService
(
key
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
m
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
m
)
.
Error
if
err
==
nil
{
applyApiKeyModelToService
(
key
,
m
)
}
return
err
}
func
(
r
*
apiKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
ApiKey
{},
id
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
apiKeyModel
{},
id
)
.
Error
}
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
apiKeyModel
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"user_id = ?"
,
userID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
...
...
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
(
)
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
)
)
for
i
:=
range
keys
{
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
return
keys
,
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
return
outKeys
,
paginationResultFromTotal
(
total
,
params
),
nil
}
func
(
r
*
apiKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
func
(
r
*
apiKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
apiKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
func
(
r
*
apiKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
apiKeyModel
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
...
...
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
(
)
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
)
)
for
i
:=
range
keys
{
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
return
keys
,
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
return
outKeys
,
paginationResultFromTotal
(
total
,
params
),
nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func
(
r
*
apiKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
model
.
ApiKey
,
error
)
{
var
keys
[]
model
.
ApiKey
func
(
r
*
apiKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
service
.
ApiKey
,
error
)
{
var
keys
[]
apiKeyModel
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
if
userID
>
0
{
db
=
db
.
Where
(
"user_id = ?"
,
userID
)
...
...
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return
nil
,
err
}
return
keys
,
nil
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
))
for
i
:=
range
keys
{
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
return
outKeys
,
nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func
(
r
*
apiKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Update
(
"group_id"
,
nil
)
return
result
.
RowsAffected
,
result
.
Error
...
...
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量
func
(
r
*
apiKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
type
apiKeyModel
struct
{
ID
int64
`gorm:"primaryKey"`
UserID
int64
`gorm:"index;not null"`
Key
string
`gorm:"uniqueIndex;size:128;not null"`
Name
string
`gorm:"size:100;not null"`
GroupID
*
int64
`gorm:"index"`
Status
string
`gorm:"size:20;default:active;not null"`
CreatedAt
time
.
Time
`gorm:"not null"`
UpdatedAt
time
.
Time
`gorm:"not null"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index"`
User
*
userModel
`gorm:"foreignKey:UserID"`
Group
*
groupModel
`gorm:"foreignKey:GroupID"`
}
func
(
apiKeyModel
)
TableName
()
string
{
return
"api_keys"
}
func
apiKeyModelToService
(
m
*
apiKeyModel
)
*
service
.
ApiKey
{
if
m
==
nil
{
return
nil
}
return
&
service
.
ApiKey
{
ID
:
m
.
ID
,
UserID
:
m
.
UserID
,
Key
:
m
.
Key
,
Name
:
m
.
Name
,
GroupID
:
m
.
GroupID
,
Status
:
m
.
Status
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
User
:
userModelToService
(
m
.
User
),
Group
:
groupModelToService
(
m
.
Group
),
}
}
func
apiKeyModelFromService
(
k
*
service
.
ApiKey
)
*
apiKeyModel
{
if
k
==
nil
{
return
nil
}
return
&
apiKeyModel
{
ID
:
k
.
ID
,
UserID
:
k
.
UserID
,
Key
:
k
.
Key
,
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
}
}
func
applyApiKeyModelToService
(
key
*
service
.
ApiKey
,
m
*
apiKeyModel
)
{
if
key
==
nil
||
m
==
nil
{
return
}
key
.
ID
=
m
.
ID
key
.
CreatedAt
=
m
.
CreatedAt
key
.
UpdatedAt
=
m
.
UpdatedAt
}
backend/internal/repository/api_key_repo_integration_test.go
View file @
e5a77853
...
...
@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
...
...
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestCreate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"create@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"create@test.com"
})
key
:=
&
model
.
ApiKey
{
key
:=
&
service
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-create-test"
,
Name
:
"Test Key"
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
key
)
...
...
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestGetByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"getbykey@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-key"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"getbykey@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-key"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-getbykey"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
})
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
...
...
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"update@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"update@test.com"
})
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-update"
,
Name
:
"Original"
,
Status
:
model
.
StatusActive
,
})
Status
:
service
.
StatusActive
,
})
)
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
key
.
Status
=
service
.
StatusDisabled
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
...
...
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s
.
Require
()
.
Equal
(
"sk-update"
,
got
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
got
.
Status
)
}
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"cleargroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-clear"
})
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clear-group"
,
Name
:
"Group Key"
,
GroupID
:
&
group
.
ID
,
})
})
)
key
.
GroupID
=
nil
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
...
...
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func
(
s
*
ApiKeyRepoSuite
)
TestDelete
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"delete@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"delete@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-delete"
,
Name
:
"Delete Me"
,
...
...
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbyuser@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-1"
,
Name
:
"Key 1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-2"
,
Name
:
"Key 2"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"listbyuser@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-list-1"
,
Name
:
"Key 1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-list-2"
,
Name
:
"Key 2"
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
...
...
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID_Pagination
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"paging@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"paging@test.com"
})
for
i
:=
0
;
i
<
5
;
i
++
{
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-page-"
+
string
(
rune
(
'a'
+
i
)),
Name
:
"Key"
,
...
...
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"count@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-2"
,
Name
:
"K2"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"count@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-count-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-count-2"
,
Name
:
"K2"
})
count
,
err
:=
s
.
repo
.
CountByUserID
(
s
.
ctx
,
user
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByUserID"
)
...
...
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbygroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-list"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"listbygroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-list"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-3"
,
Name
:
"K3"
})
// no group
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-3"
,
Name
:
"K3"
})
// no group
keys
,
page
,
err
:=
s
.
repo
.
ListByGroupID
(
s
.
ctx
,
group
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByGroupID"
)
...
...
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"countgroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"countgroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-count"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-gc-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-gc-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
count
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID"
)
...
...
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestExistsByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"exists@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-exists"
,
Name
:
"K"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"exists@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-exists"
,
Name
:
"K"
})
exists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-exists"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByKey"
)
...
...
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"search@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-1"
,
Name
:
"Production Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-2"
,
Name
:
"Development Key"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"search@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-search-1"
,
Name
:
"Production Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-search-2"
,
Name
:
"Development Key"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
"prod"
,
10
)
s
.
Require
()
.
NoError
(
err
,
"SearchApiKeys"
)
...
...
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoKeyword
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnokw@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-2"
,
Name
:
"K2"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"searchnokw@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-2"
,
Name
:
"K2"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
""
,
10
)
s
.
Require
()
.
NoError
(
err
)
...
...
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnouid@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nu-1"
,
Name
:
"TestKey"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"searchnouid@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nu-1"
,
Name
:
"TestKey"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
0
,
"testkey"
,
10
)
s
.
Require
()
.
NoError
(
err
)
...
...
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestClearGroupIDByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargrp@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear-bulk"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"cleargrp@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-clear-bulk"
})
k1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-3"
,
Name
:
"K3"
})
// no group
k1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-3"
,
Name
:
"K3"
})
// no group
affected
,
err
:=
s
.
repo
.
ClearGroupIDByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"ClearGroupIDByGroupID"
)
...
...
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func
(
s
*
ApiKeyRepoSuite
)
TestCRUD_Search_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"k@example.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-k"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"k@example.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-k"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-test-1"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
})
Status
:
service
.
StatusActive
,
})
)
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
s
.
Require
()
.
NoError
(
err
,
"GetByKey"
)
...
...
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Group
.
ID
)
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
key
.
Status
=
service
.
StatusDisabled
key
.
GroupID
=
nil
s
.
Require
()
.
NoError
(
s
.
repo
.
Update
(
s
.
ctx
,
key
),
"Update"
)
...
...
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
"sk-test-1"
,
got2
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got2
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got2
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
...
...
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
key
.
ID
,
found
[
0
]
.
ID
)
// ClearGroupIDByGroupID
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-test-2"
,
Name
:
"Group Key"
,
...
...
backend/internal/repository/auto_migrate.go
0 → 100644
View file @
e5a77853
package
repository
import
"gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func
AutoMigrate
(
db
*
gorm
.
DB
)
error
{
return
db
.
AutoMigrate
(
&
userModel
{},
&
apiKeyModel
{},
&
groupModel
{},
&
accountModel
{},
&
accountGroupModel
{},
&
proxyModel
{},
&
redeemCodeModel
{},
&
usageLogModel
{},
&
settingModel
{},
&
userSubscriptionModel
{},
)
}
backend/internal/repository/fixtures_integration_test.go
View file @
e5a77853
...
...
@@ -6,21 +6,25 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/
model
"
"github.com/Wei-Shaw/sub2api/internal/
service
"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func
mustCreateUser
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
u
*
model
.
User
)
*
model
.
User
{
func
mustCreateUser
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
u
*
userModel
)
*
userModel
{
t
.
Helper
()
if
u
.
PasswordHash
==
""
{
u
.
PasswordHash
=
"test-password-hash"
}
if
u
.
Role
==
""
{
u
.
Role
=
model
.
RoleUser
u
.
Role
=
service
.
RoleUser
}
if
u
.
Status
==
""
{
u
.
Status
=
model
.
StatusActive
u
.
Status
=
service
.
StatusActive
}
if
u
.
Concurrency
==
0
{
u
.
Concurrency
=
5
}
if
u
.
CreatedAt
.
IsZero
()
{
u
.
CreatedAt
=
time
.
Now
()
...
...
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return
u
}
func
mustCreateGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
g
*
model
.
Group
)
*
model
.
Group
{
func
mustCreateGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
g
*
groupModel
)
*
groupModel
{
t
.
Helper
()
if
g
.
Platform
==
""
{
g
.
Platform
=
model
.
PlatformAnthropic
g
.
Platform
=
service
.
PlatformAnthropic
}
if
g
.
Status
==
""
{
g
.
Status
=
model
.
StatusActive
g
.
Status
=
service
.
StatusActive
}
if
g
.
SubscriptionType
==
""
{
g
.
SubscriptionType
=
model
.
SubscriptionTypeStandard
g
.
SubscriptionType
=
service
.
SubscriptionTypeStandard
}
if
g
.
CreatedAt
.
IsZero
()
{
g
.
CreatedAt
=
time
.
Now
()
...
...
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return
g
}
func
mustCreateProxy
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
p
*
model
.
Proxy
)
*
model
.
Proxy
{
func
mustCreateProxy
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
p
*
proxyModel
)
*
proxyModel
{
t
.
Helper
()
if
p
.
Protocol
==
""
{
p
.
Protocol
=
"http"
...
...
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p
.
Port
=
8080
}
if
p
.
Status
==
""
{
p
.
Status
=
model
.
StatusActive
p
.
Status
=
service
.
StatusActive
}
if
p
.
CreatedAt
.
IsZero
()
{
p
.
CreatedAt
=
time
.
Now
()
...
...
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return
p
}
func
mustCreateAccount
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
a
*
model
.
Account
)
*
model
.
A
ccount
{
func
mustCreateAccount
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
a
*
accountModel
)
*
a
ccount
Model
{
t
.
Helper
()
if
a
.
Platform
==
""
{
a
.
Platform
=
model
.
PlatformAnthropic
a
.
Platform
=
service
.
PlatformAnthropic
}
if
a
.
Type
==
""
{
a
.
Type
=
model
.
AccountTypeOAuth
a
.
Type
=
service
.
AccountTypeOAuth
}
if
a
.
Status
==
""
{
a
.
Status
=
model
.
StatusActive
a
.
Status
=
service
.
StatusActive
}
if
!
a
.
Schedulable
{
a
.
Schedulable
=
true
}
if
a
.
Credentials
==
nil
{
a
.
Credentials
=
model
.
JSON
B
{}
a
.
Credentials
=
datatypes
.
JSON
Map
{}
}
if
a
.
Extra
==
nil
{
a
.
Extra
=
model
.
JSON
B
{}
a
.
Extra
=
datatypes
.
JSON
Map
{}
}
if
a
.
CreatedAt
.
IsZero
()
{
a
.
CreatedAt
=
time
.
Now
()
...
...
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return
a
}
func
mustCreateApiKey
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
k
*
model
.
ApiKey
)
*
model
.
ApiKey
{
func
mustCreateApiKey
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
k
*
apiKeyModel
)
*
apiKeyModel
{
t
.
Helper
()
if
k
.
Status
==
""
{
k
.
Status
=
model
.
StatusActive
k
.
Status
=
service
.
StatusActive
}
if
k
.
CreatedAt
.
IsZero
()
{
k
.
CreatedAt
=
time
.
Now
()
...
...
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return
k
}
func
mustCreateRedeemCode
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
c
*
model
.
R
edeemCode
)
*
model
.
R
edeemCode
{
func
mustCreateRedeemCode
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
c
*
r
edeemCode
Model
)
*
r
edeemCode
Model
{
t
.
Helper
()
if
c
.
Status
==
""
{
c
.
Status
=
model
.
StatusUnused
c
.
Status
=
service
.
StatusUnused
}
if
c
.
Type
==
""
{
c
.
Type
=
model
.
RedeemTypeBalance
c
.
Type
=
service
.
RedeemTypeBalance
}
if
c
.
CreatedAt
.
IsZero
()
{
c
.
CreatedAt
=
time
.
Now
()
...
...
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return
c
}
func
mustCreateSubscription
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
s
*
model
.
U
serSubscription
)
*
model
.
U
serSubscription
{
func
mustCreateSubscription
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
s
*
u
serSubscription
Model
)
*
u
serSubscription
Model
{
t
.
Helper
()
if
s
.
Status
==
""
{
s
.
Status
=
model
.
SubscriptionStatusActive
s
.
Status
=
service
.
SubscriptionStatusActive
}
now
:=
time
.
Now
()
if
s
.
StartsAt
.
IsZero
()
{
...
...
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func
mustBindAccountToGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
accountID
,
groupID
int64
,
priority
int
)
{
t
.
Helper
()
require
.
NoError
(
t
,
db
.
Create
(
&
model
.
A
ccountGroup
{
require
.
NoError
(
t
,
db
.
Create
(
&
a
ccountGroup
Model
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
priority
,
CreatedAt
:
time
.
Now
(),
})
.
Error
,
"create account_group"
)
}
backend/internal/repository/group_repo.go
View file @
e5a77853
...
...
@@ -2,10 +2,10 @@ package repository
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
...
...
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return
&
groupRepository
{
db
:
db
}
}
func
(
r
*
groupRepository
)
Create
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
group
)
.
Error
func
(
r
*
groupRepository
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
m
:=
groupModelFromService
(
group
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
m
)
.
Error
if
err
==
nil
{
applyGroupModelToService
(
group
,
m
)
}
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrGroupExists
)
}
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
var
group
m
odel
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
group
,
id
)
.
Error
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
var
m
group
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
m
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
return
&
group
,
nil
group
:=
groupModelToService
(
&
m
)
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
group
.
ID
)
group
.
AccountCount
=
count
return
group
,
nil
}
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
group
)
.
Error
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
m
:=
groupModelFromService
(
group
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Save
(
m
)
.
Error
if
err
==
nil
{
applyGroupModelToService
(
group
,
m
)
}
return
err
}
func
(
r
*
groupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
groupModel
{},
id
)
.
Error
}
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
var
groups
[]
model
.
Group
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
var
groups
[]
groupModel
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
groupModel
{})
// Apply filters
if
platform
!=
""
{
...
...
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return
nil
,
nil
,
err
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
// 获取每个分组的账号数量
for
i
:=
range
outGroups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
return
groups
,
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
return
outGroups
,
paginationResultFromTotal
(
total
,
params
),
nil
}
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
var
groups
[]
groupModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
service
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
}
// 获取每个分组的账号数量
for
i
:=
range
outGroups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
return
g
roups
,
nil
return
outG
roups
,
nil
}
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
model
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
var
groups
[]
groupModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
service
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
}
// 获取每个分组的账号数量
for
i
:=
range
outGroups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
return
g
roups
,
nil
return
outG
roups
,
nil
}
func
(
r
*
groupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
groupModel
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
groupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
G
roup
{}
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Table
(
"a
ccount
_g
roup
s"
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func
(
r
*
groupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{}
)
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Exec
(
"DELETE FROM account_groups WHERE group_id = ?"
,
groupID
)
return
result
.
RowsAffected
,
result
.
Error
}
...
...
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var
affectedUserIDs
[]
int64
if
group
.
IsSubscriptionType
()
{
var
subscriptions
[]
model
.
UserSubscription
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
U
ser
S
ubscription
{}
)
.
Table
(
"u
ser
_s
ubscription
s"
)
.
Where
(
"group_id = ?"
,
id
)
.
Select
(
"user_id"
)
.
Find
(
&
subscriptions
)
.
Error
;
err
!=
nil
{
Pluck
(
"user_id"
,
&
affectedUserIDs
)
.
Error
;
err
!=
nil
{
return
nil
,
err
}
for
_
,
sub
:=
range
subscriptions
{
affectedUserIDs
=
append
(
affectedUserIDs
,
sub
.
UserID
)
}
}
err
=
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
// 1. 删除订阅类型分组的订阅记录
if
group
.
IsSubscriptionType
()
{
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
UserSubscription
{}
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"DELETE FROM user_subscriptions WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if
err
:=
tx
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
id
)
.
Update
(
"group_id"
,
nil
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"UPDATE api_keys SET group_id = NULL WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if
err
:=
tx
.
Model
(
&
model
.
User
{})
.
Where
(
"? = ANY(allowed_groups)"
,
id
)
.
Update
(
"allowed_groups"
,
gorm
.
Expr
(
"array_remove(allowed_groups, ?)"
,
id
))
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)"
,
id
,
id
,
)
.
Error
;
err
!=
nil
{
return
err
}
// 4. 删除 account_groups 中间表的数据
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{}
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"DELETE FROM account_groups WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
}
// 5. 删除分组本身(带锁,避免并发写)
if
err
:=
tx
.
Clauses
(
clause
.
Locking
{
Strength
:
"UPDATE"
})
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Clauses
(
clause
.
Locking
{
Strength
:
"UPDATE"
})
.
Delete
(
&
groupModel
{},
id
)
.
Error
;
err
!=
nil
{
return
err
}
...
...
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return
affectedUserIDs
,
nil
}
type
groupModel
struct
{
ID
int64
`gorm:"primaryKey"`
Name
string
`gorm:"uniqueIndex;size:100;not null"`
Description
string
`gorm:"type:text"`
Platform
string
`gorm:"size:50;default:anthropic;not null"`
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive
bool
`gorm:"default:false;not null"`
Status
string
`gorm:"size:20;default:active;not null"`
SubscriptionType
string
`gorm:"size:20;default:standard;not null"`
DailyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
WeeklyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
MonthlyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
CreatedAt
time
.
Time
`gorm:"not null"`
UpdatedAt
time
.
Time
`gorm:"not null"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index"`
}
func
(
groupModel
)
TableName
()
string
{
return
"groups"
}
func
groupModelToService
(
m
*
groupModel
)
*
service
.
Group
{
if
m
==
nil
{
return
nil
}
return
&
service
.
Group
{
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Description
:
m
.
Description
,
Platform
:
m
.
Platform
,
RateMultiplier
:
m
.
RateMultiplier
,
IsExclusive
:
m
.
IsExclusive
,
Status
:
m
.
Status
,
SubscriptionType
:
m
.
SubscriptionType
,
DailyLimitUSD
:
m
.
DailyLimitUSD
,
WeeklyLimitUSD
:
m
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
m
.
MonthlyLimitUSD
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
}
}
func
groupModelFromService
(
sg
*
service
.
Group
)
*
groupModel
{
if
sg
==
nil
{
return
nil
}
return
&
groupModel
{
ID
:
sg
.
ID
,
Name
:
sg
.
Name
,
Description
:
sg
.
Description
,
Platform
:
sg
.
Platform
,
RateMultiplier
:
sg
.
RateMultiplier
,
IsExclusive
:
sg
.
IsExclusive
,
Status
:
sg
.
Status
,
SubscriptionType
:
sg
.
SubscriptionType
,
DailyLimitUSD
:
sg
.
DailyLimitUSD
,
WeeklyLimitUSD
:
sg
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
sg
.
MonthlyLimitUSD
,
CreatedAt
:
sg
.
CreatedAt
,
UpdatedAt
:
sg
.
UpdatedAt
,
}
}
func
applyGroupModelToService
(
group
*
service
.
Group
,
m
*
groupModel
)
{
if
group
==
nil
||
m
==
nil
{
return
}
group
.
ID
=
m
.
ID
group
.
CreatedAt
=
m
.
CreatedAt
group
.
UpdatedAt
=
m
.
UpdatedAt
}
backend/internal/repository/group_repo_integration_test.go
View file @
e5a77853
...
...
@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
...
...
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func
(
s
*
GroupRepoSuite
)
TestCreate
()
{
group
:=
&
model
.
Group
{
group
:=
&
service
.
Group
{
Name
:
"test-create"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
,
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
group
)
...
...
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"original"
})
group
:=
groupModelToService
(
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"original"
})
)
group
.
Name
=
"updated"
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
group
)
...
...
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
}
func
(
s
*
GroupRepoSuite
)
TestDelete
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"to-delete"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"to-delete"
})
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
...
...
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func
(
s
*
GroupRepoSuite
)
TestList
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
})
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"List"
)
...
...
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Platform
:
service
.
PlatformAnthropic
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Platform
:
service
.
PlatformOpenAI
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformOpenAI
,
""
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
s
.
Require
()
.
Equal
(
service
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Status
:
service
.
StatusDisabled
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
model
.
StatusDisabled
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
groups
[
0
]
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
groups
[
0
]
.
Status
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_IsExclusive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
IsExclusive
:
false
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
IsExclusive
:
true
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
IsExclusive
:
false
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
IsExclusive
:
true
})
isExclusive
:=
true
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
...
...
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
,
})
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
,
IsExclusive
:
true
,
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc1"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc1"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g1
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g2
.
ID
,
1
)
isExclusive
:=
true
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformAnthropic
,
model
.
StatusActive
,
&
isExclusive
)
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
1
)
...
...
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
func
(
s
*
GroupRepoSuite
)
TestListActive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"active1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"inactive1"
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"active1"
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"inactive1"
,
Status
:
service
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
...
...
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
}
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g3"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Platform
:
service
.
PlatformOpenAI
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g3"
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
model
.
PlatformAnthropic
)
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
"g1"
,
groups
[
0
]
.
Name
)
...
...
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName ---
func
(
s
*
GroupRepoSuite
)
TestExistsByName
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"existing-group"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"existing-group"
})
exists
,
err
:=
s
.
repo
.
ExistsByName
(
s
.
ctx
,
"existing-group"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByName"
)
...
...
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount ---
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-count"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
...
...
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
}
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount_Empty
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-empty"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-empty"
})
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
...
...
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID ---
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-del"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-del"
})
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-del"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-del"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g
.
ID
,
1
)
affected
,
err
:=
s
.
repo
.
DeleteAccountGroupsByGroupID
(
s
.
ctx
,
g
.
ID
)
...
...
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID_MultipleAccounts
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-multi"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
})
a3
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a3"
})
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-multi"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
})
a3
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a3"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
g
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
g
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a3
.
ID
,
g
.
ID
,
3
)
...
...
backend/internal/repository/integration_harness_test.go
View file @
e5a77853
...
...
@@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
...
...
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log
.
Printf
(
"failed to open gorm db: %v"
,
err
)
os
.
Exit
(
1
)
}
if
err
:=
model
.
AutoMigrate
(
integrationDB
);
err
!=
nil
{
if
err
:=
AutoMigrate
(
integrationDB
);
err
!=
nil
{
log
.
Printf
(
"failed to automigrate db: %v"
,
err
)
os
.
Exit
(
1
)
}
...
...
backend/internal/repository/pagination.go
0 → 100644
View file @
e5a77853
package
repository
import
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func
paginationResultFromTotal
(
total
int64
,
params
pagination
.
PaginationParams
)
*
pagination
.
PaginationResult
{
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
}
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment