Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
22f07a7b
Commit
22f07a7b
authored
Dec 26, 2025
by
shaw
Browse files
Merge PR #36: refactor: 调整项目结构为单向依赖
parents
ecb2c535
e5a77853
Changes
95
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/user_handler.go
View file @
22f07a7b
package
handler
package
handler
import
(
import
(
"github.com/Wei-Shaw/sub2api/internal/
model
"
"github.com/Wei-Shaw/sub2api/internal/
handler/dto
"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
...
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
...
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile
// GetProfile handles getting user profile
// GET /api/v1/users/me
// GET /api/v1/users/me
func
(
h
*
UserHandler
)
GetProfile
(
c
*
gin
.
Context
)
{
func
(
h
*
UserHandler
)
GetProfile
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
if
!
ok
{
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
return
}
}
userData
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
u
ser
.
ID
)
userData
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
subject
.
U
serID
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
...
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注
// 清空notes字段,普通用户不应看到备注
userData
.
Notes
=
""
userData
.
Notes
=
""
response
.
Success
(
c
,
userData
)
response
.
Success
(
c
,
dto
.
UserFromService
(
userData
)
)
}
}
// ChangePassword handles changing user password
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
// POST /api/v1/users/me/password
func
(
h
*
UserHandler
)
ChangePassword
(
c
*
gin
.
Context
)
{
func
(
h
*
UserHandler
)
ChangePassword
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
if
!
ok
{
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
return
}
}
...
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
...
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword
:
req
.
OldPassword
,
CurrentPassword
:
req
.
OldPassword
,
NewPassword
:
req
.
NewPassword
,
NewPassword
:
req
.
NewPassword
,
}
}
err
:=
h
.
userService
.
ChangePassword
(
c
.
Request
.
Context
(),
u
ser
.
ID
,
svcReq
)
err
:=
h
.
userService
.
ChangePassword
(
c
.
Request
.
Context
(),
subject
.
U
serID
,
svcReq
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
...
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
// PUT /api/v1/users/me
func
(
h
*
UserHandler
)
UpdateProfile
(
c
*
gin
.
Context
)
{
func
(
h
*
UserHandler
)
UpdateProfile
(
c
*
gin
.
Context
)
{
userValue
,
exists
:=
c
.
Get
(
"user"
)
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
)
if
!
exists
{
response
.
Unauthorized
(
c
,
"User not authenticated"
)
return
}
user
,
ok
:=
userValue
.
(
*
model
.
User
)
if
!
ok
{
if
!
ok
{
response
.
InternalError
(
c
,
"Invalid user context
"
)
response
.
Unauthorized
(
c
,
"User not authenticated
"
)
return
return
}
}
...
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
...
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username
:
req
.
Username
,
Username
:
req
.
Username
,
Wechat
:
req
.
Wechat
,
Wechat
:
req
.
Wechat
,
}
}
updatedUser
,
err
:=
h
.
userService
.
UpdateProfile
(
c
.
Request
.
Context
(),
u
ser
.
ID
,
svcReq
)
updatedUser
,
err
:=
h
.
userService
.
UpdateProfile
(
c
.
Request
.
Context
(),
subject
.
U
serID
,
svcReq
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
...
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注
// 清空notes字段,普通用户不应看到备注
updatedUser
.
Notes
=
""
updatedUser
.
Notes
=
""
response
.
Success
(
c
,
updatedUser
)
response
.
Success
(
c
,
dto
.
UserFromService
(
updatedUser
)
)
}
}
backend/internal/infrastructure/database.go
View file @
22f07a7b
...
@@ -2,8 +2,8 @@ package infrastructure
...
@@ -2,8 +2,8 @@ package infrastructure
import
(
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm"
...
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
...
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步)
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if
err
:=
model
.
AutoMigrate
(
db
);
err
!=
nil
{
if
err
:=
repository
.
AutoMigrate
(
db
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
...
backend/internal/model/account_group.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"time"
)
type
AccountGroup
struct
{
AccountID
int64
`gorm:"primaryKey" json:"account_id"`
GroupID
int64
`gorm:"primaryKey" json:"group_id"`
Priority
int
`gorm:"default:50;not null" json:"priority"`
// 分组内优先级
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 关联
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
AccountGroup
)
TableName
()
string
{
return
"account_groups"
}
backend/internal/model/api_key.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"time"
"gorm.io/gorm"
)
type
ApiKey
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
Key
string
`gorm:"uniqueIndex;size:128;not null" json:"key"`
// sk-xxx
Name
string
`gorm:"size:100;not null" json:"name"`
GroupID
*
int64
`gorm:"index" json:"group_id"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
ApiKey
)
TableName
()
string
{
return
"api_keys"
}
// IsActive 检查是否激活
func
(
k
*
ApiKey
)
IsActive
()
bool
{
return
k
.
Status
==
"active"
}
backend/internal/model/group.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const
(
SubscriptionTypeStandard
=
"standard"
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"subscription"
// 订阅模式(按限额控制)
)
type
Group
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"uniqueIndex;size:100;not null" json:"name"`
Description
string
`gorm:"type:text" json:"description"`
Platform
string
`gorm:"size:50;default:anthropic;not null" json:"platform"`
// anthropic/openai/gemini
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive
bool
`gorm:"default:false;not null" json:"is_exclusive"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
// 订阅功能字段
SubscriptionType
string
`gorm:"size:20;default:standard;not null" json:"subscription_type"`
// standard/subscription
DailyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
AccountGroups
[]
AccountGroup
`gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount
int64
`gorm:"-" json:"account_count,omitempty"`
}
func
(
Group
)
TableName
()
string
{
return
"groups"
}
// IsActive 检查是否激活
func
(
g
*
Group
)
IsActive
()
bool
{
return
g
.
Status
==
"active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func
(
g
*
Group
)
IsSubscriptionType
()
bool
{
return
g
.
SubscriptionType
==
SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func
(
g
*
Group
)
IsFreeSubscription
()
bool
{
return
g
.
IsSubscriptionType
()
&&
g
.
RateMultiplier
==
0
}
// HasDailyLimit 检查是否有日限额
func
(
g
*
Group
)
HasDailyLimit
()
bool
{
return
g
.
DailyLimitUSD
!=
nil
&&
*
g
.
DailyLimitUSD
>
0
}
// HasWeeklyLimit 检查是否有周限额
func
(
g
*
Group
)
HasWeeklyLimit
()
bool
{
return
g
.
WeeklyLimitUSD
!=
nil
&&
*
g
.
WeeklyLimitUSD
>
0
}
// HasMonthlyLimit 检查是否有月限额
func
(
g
*
Group
)
HasMonthlyLimit
()
bool
{
return
g
.
MonthlyLimitUSD
!=
nil
&&
*
g
.
MonthlyLimitUSD
>
0
}
backend/internal/model/model.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func
AutoMigrate
(
db
*
gorm
.
DB
)
error
{
return
db
.
AutoMigrate
(
&
User
{},
&
ApiKey
{},
&
Group
{},
&
Account
{},
&
AccountGroup
{},
&
Proxy
{},
&
RedeemCode
{},
&
UsageLog
{},
&
Setting
{},
&
UserSubscription
{},
)
}
// 状态常量
const
(
StatusActive
=
"active"
StatusDisabled
=
"disabled"
StatusError
=
"error"
StatusUnused
=
"unused"
StatusUsed
=
"used"
StatusExpired
=
"expired"
)
// 角色常量
const
(
RoleAdmin
=
"admin"
RoleUser
=
"user"
)
// 平台常量
const
(
PlatformAnthropic
=
"anthropic"
PlatformOpenAI
=
"openai"
PlatformGemini
=
"gemini"
)
// 账号类型常量
const
(
AccountTypeOAuth
=
"oauth"
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"setup-token"
// Setup Token类型账号(inference only scope)
AccountTypeApiKey
=
"apikey"
// API Key类型账号
)
// 卡密类型常量
const
(
RedeemTypeBalance
=
"balance"
RedeemTypeConcurrency
=
"concurrency"
RedeemTypeSubscription
=
"subscription"
)
// 管理员调整类型常量
const
(
AdjustmentTypeAdminBalance
=
"admin_balance"
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"admin_concurrency"
// 管理员调整并发数
)
backend/internal/model/proxy.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"fmt"
"time"
"gorm.io/gorm"
)
type
Proxy
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"size:100;not null" json:"name"`
Protocol
string
`gorm:"size:20;not null" json:"protocol"`
// http/https/socks5
Host
string
`gorm:"size:255;not null" json:"host"`
Port
int
`gorm:"not null" json:"port"`
Username
string
`gorm:"size:100" json:"username"`
Password
string
`gorm:"size:100" json:"-"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
}
func
(
Proxy
)
TableName
()
string
{
return
"proxies"
}
// IsActive 检查是否激活
func
(
p
*
Proxy
)
IsActive
()
bool
{
return
p
.
Status
==
"active"
}
// URL 返回代理URL
func
(
p
*
Proxy
)
URL
()
string
{
if
p
.
Username
!=
""
&&
p
.
Password
!=
""
{
return
fmt
.
Sprintf
(
"%s://%s:%s@%s:%d"
,
p
.
Protocol
,
p
.
Username
,
p
.
Password
,
p
.
Host
,
p
.
Port
)
}
return
fmt
.
Sprintf
(
"%s://%s:%d"
,
p
.
Protocol
,
p
.
Host
,
p
.
Port
)
}
// ProxyWithAccountCount extends Proxy with account count information
type
ProxyWithAccountCount
struct
{
Proxy
AccountCount
int64
`json:"account_count"`
}
backend/internal/model/redeem_code.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"crypto/rand"
"encoding/hex"
"time"
)
type
RedeemCode
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Code
string
`gorm:"uniqueIndex;size:32;not null" json:"code"`
Type
string
`gorm:"size:20;default:balance;not null" json:"type"`
// balance/concurrency/subscription
Value
float64
`gorm:"type:decimal(20,8);not null" json:"value"`
// 面值(USD)或并发数或有效天数
Status
string
`gorm:"size:20;default:unused;not null" json:"status"`
// unused/used
UsedBy
*
int64
`gorm:"index" json:"used_by"`
UsedAt
*
time
.
Time
`json:"used_at"`
Notes
string
`gorm:"type:text" json:"notes"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID
*
int64
`gorm:"index" json:"group_id"`
// 订阅分组ID (仅subscription类型使用)
ValidityDays
int
`gorm:"default:30" json:"validity_days"`
// 订阅有效天数 (仅subscription类型使用)
// 关联
User
*
User
`gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
RedeemCode
)
TableName
()
string
{
return
"redeem_codes"
}
// IsUsed 检查是否已使用
func
(
r
*
RedeemCode
)
IsUsed
()
bool
{
return
r
.
Status
==
"used"
}
// CanUse 检查是否可以使用
func
(
r
*
RedeemCode
)
CanUse
()
bool
{
return
r
.
Status
==
"unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func
GenerateRedeemCode
()
(
string
,
error
)
{
b
:=
make
([]
byte
,
16
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
b
),
nil
}
backend/internal/model/usage_log.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"time"
)
// 消费类型常量
const
(
BillingTypeBalance
int8
=
0
// 钱包余额
BillingTypeSubscription
int8
=
1
// 订阅套餐
)
type
UsageLog
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
ApiKeyID
int64
`gorm:"index;not null" json:"api_key_id"`
AccountID
int64
`gorm:"index;not null" json:"account_id"`
RequestID
string
`gorm:"size:64" json:"request_id"`
Model
string
`gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID
*
int64
`gorm:"index" json:"group_id"`
SubscriptionID
*
int64
`gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens
int
`gorm:"default:0;not null" json:"input_tokens"`
OutputTokens
int
`gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens
int
`gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens
int
`gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens
int
`gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens
int
`gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"`
// 原始总费用
ActualCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"`
// 实际扣除费用
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"`
// 计费倍率
// 元数据
BillingType
int8
`gorm:"type:smallint;default:0;not null" json:"billing_type"`
// 0=余额 1=订阅
Stream
bool
`gorm:"default:false;not null" json:"stream"`
DurationMs
*
int
`json:"duration_ms"`
FirstTokenMs
*
int
`json:"first_token_ms"`
// 首字时间(流式请求)
CreatedAt
time
.
Time
`gorm:"index;not null" json:"created_at"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey
*
ApiKey
`gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription
*
UserSubscription
`gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func
(
UsageLog
)
TableName
()
string
{
return
"usage_logs"
}
// TotalTokens 总token数
func
(
u
*
UsageLog
)
TotalTokens
()
int
{
return
u
.
InputTokens
+
u
.
OutputTokens
+
u
.
CacheCreationTokens
+
u
.
CacheReadTokens
}
backend/internal/model/user.go
deleted
100644 → 0
View file @
ecb2c535
package
model
import
(
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type
User
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Email
string
`gorm:"uniqueIndex;size:255;not null" json:"email"`
Username
string
`gorm:"size:100;default:''" json:"username"`
Wechat
string
`gorm:"size:100;default:''" json:"wechat"`
Notes
string
`gorm:"type:text;default:''" json:"notes"`
PasswordHash
string
`gorm:"size:255;not null" json:"-"`
Role
string
`gorm:"size:20;default:user;not null" json:"role"`
// admin/user
Balance
float64
`gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency
int
`gorm:"default:5;not null" json:"concurrency"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
AllowedGroups
pq
.
Int64Array
`gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
ApiKeys
[]
ApiKey
`gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions
[]
UserSubscription
`gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func
(
User
)
TableName
()
string
{
return
"users"
}
// IsAdmin 检查是否管理员
func
(
u
*
User
)
IsAdmin
()
bool
{
return
u
.
Role
==
"admin"
}
// IsActive 检查是否激活
func
(
u
*
User
)
IsActive
()
bool
{
return
u
.
Status
==
"active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func
(
u
*
User
)
CanBindGroup
(
groupID
int64
,
isExclusive
bool
)
bool
{
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if
len
(
u
.
AllowedGroups
)
>
0
{
for
_
,
id
:=
range
u
.
AllowedGroups
{
if
id
==
groupID
{
return
true
}
}
return
false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return
!
isExclusive
}
// SetPassword 设置密码(哈希存储)
func
(
u
*
User
)
SetPassword
(
password
string
)
error
{
hash
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
err
}
u
.
PasswordHash
=
string
(
hash
)
return
nil
}
// CheckPassword 验证密码
func
(
u
*
User
)
CheckPassword
(
password
string
)
bool
{
err
:=
bcrypt
.
CompareHashAndPassword
([]
byte
(
u
.
PasswordHash
),
[]
byte
(
password
))
return
err
==
nil
}
backend/internal/repository/account_repo.go
View file @
22f07a7b
...
@@ -5,10 +5,10 @@ import (
...
@@ -5,10 +5,10 @@ import (
"errors"
"errors"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/clause"
)
)
...
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
...
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return
&
accountRepository
{
db
:
db
}
return
&
accountRepository
{
db
:
db
}
}
}
func
(
r
*
accountRepository
)
Create
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
func
(
r
*
accountRepository
)
Create
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
account
)
.
Error
m
:=
accountModelFromService
(
account
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
m
)
.
Error
if
err
==
nil
{
applyAccountModelToService
(
account
,
m
)
}
return
err
}
}
func
(
r
*
accountRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
var
account
m
odel
.
Account
var
m
account
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Proxy"
)
.
Preload
(
"AccountGroups.Group"
)
.
First
(
&
account
,
id
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Proxy"
)
.
Preload
(
"AccountGroups.Group"
)
.
First
(
&
m
,
id
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
}
}
// 填充 GroupIDs 和 Groups 虚拟字段
return
accountModelToService
(
&
m
),
nil
account
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
account
.
AccountGroups
))
account
.
Groups
=
make
([]
*
model
.
Group
,
0
,
len
(
account
.
AccountGroups
))
for
_
,
ag
:=
range
account
.
AccountGroups
{
account
.
GroupIDs
=
append
(
account
.
GroupIDs
,
ag
.
GroupID
)
if
ag
.
Group
!=
nil
{
account
.
Groups
=
append
(
account
.
Groups
,
ag
.
Group
)
}
}
return
&
account
,
nil
}
}
func
(
r
*
accountRepository
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
model
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
service
.
Account
,
error
)
{
if
crsAccountID
==
""
{
if
crsAccountID
==
""
{
return
nil
,
nil
return
nil
,
nil
}
}
var
account
m
odel
.
Account
var
m
account
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"extra->>'crs_account_id' = ?"
,
crsAccountID
)
.
First
(
&
account
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"extra->>'crs_account_id' = ?"
,
crsAccountID
)
.
First
(
&
m
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
nil
return
nil
,
nil
}
}
return
nil
,
err
return
nil
,
err
}
}
return
&
account
,
nil
return
account
ModelToService
(
&
m
)
,
nil
}
}
func
(
r
*
accountRepository
)
Update
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
func
(
r
*
accountRepository
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
account
)
.
Error
m
:=
accountModelFromService
(
account
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Save
(
m
)
.
Error
if
err
==
nil
{
applyAccountModelToService
(
account
,
m
)
}
return
err
}
}
func
(
r
*
accountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
accountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 先删除账号与分组的绑定关系
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
id
)
.
Delete
(
&
accountGroupModel
{})
.
Error
;
err
!=
nil
{
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// 再删除账号
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
accountModel
{},
id
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Account
{},
id
)
.
Error
}
}
func
(
r
*
accountRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
accountRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
)
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
)
}
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func
(
r
*
accountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
service
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
accountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
var
accounts
[]
accountModel
var
accounts
[]
model
.
Account
var
total
int64
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
// Apply filters
if
platform
!=
""
{
if
platform
!=
""
{
db
=
db
.
Where
(
"platform = ?"
,
platform
)
db
=
db
.
Where
(
"platform = ?"
,
platform
)
}
}
...
@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
...
@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
accounts
[
i
]
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
accounts
[
i
]
.
AccountGroups
))
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
accounts
[
i
]
.
Groups
=
make
([]
*
model
.
Group
,
0
,
len
(
accounts
[
i
]
.
AccountGroups
))
for
_
,
ag
:=
range
accounts
[
i
]
.
AccountGroups
{
accounts
[
i
]
.
GroupIDs
=
append
(
accounts
[
i
]
.
GroupIDs
,
ag
.
GroupID
)
if
ag
.
Group
!=
nil
{
accounts
[
i
]
.
Groups
=
append
(
accounts
[
i
]
.
Groups
,
ag
.
Group
)
}
}
}
}
pages
:=
int
(
total
)
/
params
.
Limit
()
return
outAccounts
,
paginationResultFromTotal
(
total
,
params
),
nil
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
accounts
,
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
}
func
(
r
*
accountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
var
accounts
[]
model
.
A
ccount
var
accounts
[]
a
ccount
Model
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ? AND accounts.status = ?"
,
groupID
,
model
.
StatusActive
)
.
Where
(
"account_groups.group_id = ? AND accounts.status = ?"
,
groupID
,
service
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
func
(
r
*
accountRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
var
accounts
[]
model
.
A
ccount
var
accounts
[]
a
ccount
Model
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Where
(
"status = ?"
,
service
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
func
(
r
*
accountRepository
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
var
accounts
[]
accountModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform = ? AND status = ?"
,
platform
,
service
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
func
(
r
*
accountRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
accountRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"last_used_at"
,
now
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"last_used_at"
,
now
)
.
Error
}
}
func
(
r
*
accountRepository
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
func
(
r
*
accountRepository
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
Updates
(
map
[
string
]
any
{
"status"
:
model
.
StatusError
,
"status"
:
service
.
StatusError
,
"error_message"
:
errorMsg
,
"error_message"
:
errorMsg
,
})
.
Error
})
.
Error
}
}
func
(
r
*
accountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
func
(
r
*
accountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
ag
:=
&
model
.
A
ccountGroup
{
ag
:=
&
a
ccountGroup
Model
{
AccountID
:
accountID
,
AccountID
:
accountID
,
GroupID
:
groupID
,
GroupID
:
groupID
,
Priority
:
priority
,
Priority
:
priority
,
...
@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
...
@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
func
(
r
*
accountRepository
)
RemoveFromGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
)
error
{
func
(
r
*
accountRepository
)
RemoveFromGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND group_id = ?"
,
accountID
,
groupID
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND group_id = ?"
,
accountID
,
groupID
)
.
Delete
(
&
model
.
A
ccountGroup
{})
.
Error
Delete
(
&
a
ccountGroup
Model
{})
.
Error
}
}
func
(
r
*
accountRepository
)
GetGroups
(
ctx
context
.
Context
,
accountID
int64
)
([]
model
.
Group
,
error
)
{
func
(
r
*
accountRepository
)
GetGroups
(
ctx
context
.
Context
,
accountID
int64
)
([]
service
.
Group
,
error
)
{
var
groups
[]
model
.
Group
var
groups
[]
groupModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.group_id = groups.id"
)
.
Joins
(
"JOIN account_groups ON account_groups.group_id = groups.id"
)
.
Where
(
"account_groups.account_id = ?"
,
accountID
)
.
Where
(
"account_groups.account_id = ?"
,
accountID
)
.
Find
(
&
groups
)
.
Error
Find
(
&
groups
)
.
Error
return
groups
,
err
if
err
!=
nil
{
}
return
nil
,
err
}
func
(
r
*
accountRepository
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
var
accounts
[]
model
.
Account
for
i
:=
range
groups
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
Where
(
"platform = ? AND status = ?"
,
platform
,
model
.
StatusActive
)
.
}
Preload
(
"Proxy"
)
.
return
outGroups
,
nil
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
}
func
(
r
*
accountRepository
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
func
(
r
*
accountRepository
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
// 删除现有绑定
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
accountID
)
.
Delete
(
&
accountGroupModel
{})
.
Error
;
err
!=
nil
{
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
accountID
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// 添加新绑定
if
len
(
groupIDs
)
==
0
{
if
len
(
groupIDs
)
>
0
{
return
nil
accountGroups
:=
make
([]
model
.
AccountGroup
,
0
,
len
(
groupIDs
))
for
i
,
groupID
:=
range
groupIDs
{
accountGroups
=
append
(
accountGroups
,
model
.
AccountGroup
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
i
+
1
,
// 使用索引作为优先级
})
}
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
&
accountGroups
)
.
Error
}
}
return
nil
accountGroups
:=
make
([]
accountGroupModel
,
0
,
len
(
groupIDs
))
for
i
,
groupID
:=
range
groupIDs
{
accountGroups
=
append
(
accountGroups
,
accountGroupModel
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
i
+
1
,
})
}
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
&
accountGroups
)
.
Error
}
}
// ListSchedulable 获取所有可调度的账号
func
(
r
*
accountRepository
)
ListSchedulable
(
ctx
context
.
Context
)
([]
service
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListSchedulable
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
accountModel
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"status = ? AND schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
// ListSchedulableByGroupID 按组获取可调度的账号
func
(
r
*
accountRepository
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
accountModel
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func
(
r
*
accountRepository
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
accountModel
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform = ?"
,
platform
)
.
Where
(
"platform = ?"
,
platform
)
.
Where
(
"status = ? AND schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"status = ? AND schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func
(
r
*
accountRepository
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
service
.
Account
,
error
)
{
func
(
r
*
accountRepository
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
accountModel
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"accounts.platform = ?"
,
platform
)
.
Where
(
"accounts.platform = ?"
,
platform
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
Find
(
&
accounts
)
.
Error
return
accounts
,
err
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
}
// SetRateLimited 标记账号为限流状态(429)
func
(
r
*
accountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
func
(
r
*
accountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
Updates
(
map
[
string
]
any
{
"rate_limited_at"
:
now
,
"rate_limited_at"
:
now
,
"rate_limit_reset_at"
:
resetAt
,
"rate_limit_reset_at"
:
resetAt
,
})
.
Error
})
.
Error
}
}
// SetOverloaded 标记账号为过载状态(529)
func
(
r
*
accountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
func
(
r
*
accountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"overload_until"
,
until
)
.
Error
Update
(
"overload_until"
,
until
)
.
Error
}
}
// ClearRateLimit 清除账号的限流状态
func
(
r
*
accountRepository
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
accountRepository
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
Updates
(
map
[
string
]
any
{
"rate_limited_at"
:
nil
,
"rate_limited_at"
:
nil
,
"rate_limit_reset_at"
:
nil
,
"rate_limit_reset_at"
:
nil
,
...
@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
...
@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
})
.
Error
})
.
Error
}
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func
(
r
*
accountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
func
(
r
*
accountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
updates
:=
map
[
string
]
any
{
updates
:=
map
[
string
]
any
{
"session_window_status"
:
status
,
"session_window_status"
:
status
,
...
@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
...
@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if
end
!=
nil
{
if
end
!=
nil
{
updates
[
"session_window_end"
]
=
end
updates
[
"session_window_end"
]
=
end
}
}
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
updates
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
updates
)
.
Error
}
}
// SetSchedulable 设置账号的调度开关
func
(
r
*
accountRepository
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
func
(
r
*
accountRepository
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"schedulable"
,
schedulable
)
.
Error
Update
(
"schedulable"
,
schedulable
)
.
Error
}
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func
(
r
*
accountRepository
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
func
(
r
*
accountRepository
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
if
len
(
updates
)
==
0
{
if
len
(
updates
)
==
0
{
return
nil
return
nil
}
}
// Get current account to preserve existing Extra data
var
account
accountModel
var
account
model
.
Account
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Select
(
"extra"
)
.
Where
(
"id = ?"
,
id
)
.
First
(
&
account
)
.
Error
;
err
!=
nil
{
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Select
(
"extra"
)
.
Where
(
"id = ?"
,
id
)
.
First
(
&
account
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// Initialize Extra if nil
if
account
.
Extra
==
nil
{
if
account
.
Extra
==
nil
{
account
.
Extra
=
make
(
model
.
JSONB
)
account
.
Extra
=
datatypes
.
JSONMap
{}
}
}
// Merge updates into existing Extra
for
k
,
v
:=
range
updates
{
for
k
,
v
:=
range
updates
{
account
.
Extra
[
k
]
=
v
account
.
Extra
[
k
]
=
v
}
}
// Save updated Extra
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
accountModel
{})
.
Where
(
"id = ?"
,
id
)
.
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"extra"
,
account
.
Extra
)
.
Error
Update
(
"extra"
,
account
.
Extra
)
.
Error
}
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func
(
r
*
accountRepository
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
func
(
r
*
accountRepository
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
if
len
(
ids
)
==
0
{
if
len
(
ids
)
==
0
{
return
0
,
nil
return
0
,
nil
...
@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
...
@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
updateMap
[
"status"
]
=
*
updates
.
Status
updateMap
[
"status"
]
=
*
updates
.
Status
}
}
if
len
(
updates
.
Credentials
)
>
0
{
if
len
(
updates
.
Credentials
)
>
0
{
updateMap
[
"credentials"
]
=
gorm
.
Expr
(
"COALESCE(credentials,'{}') || ?"
,
updates
.
Credentials
)
updateMap
[
"credentials"
]
=
gorm
.
Expr
(
"COALESCE(credentials,'{}') || ?"
,
datatypes
.
JSONMap
(
updates
.
Credentials
)
)
}
}
if
len
(
updates
.
Extra
)
>
0
{
if
len
(
updates
.
Extra
)
>
0
{
updateMap
[
"extra"
]
=
gorm
.
Expr
(
"COALESCE(extra,'{}') || ?"
,
updates
.
Extra
)
updateMap
[
"extra"
]
=
gorm
.
Expr
(
"COALESCE(extra,'{}') || ?"
,
datatypes
.
JSONMap
(
updates
.
Extra
)
)
}
}
if
len
(
updateMap
)
==
0
{
if
len
(
updateMap
)
==
0
{
...
@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
...
@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
}
}
result
:=
r
.
db
.
WithContext
(
ctx
)
.
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
{})
.
Model
(
&
a
ccount
Model
{})
.
Where
(
"id IN ?"
,
ids
)
.
Where
(
"id IN ?"
,
ids
)
.
Clauses
(
clause
.
Returning
{})
.
Clauses
(
clause
.
Returning
{})
.
Updates
(
updateMap
)
Updates
(
updateMap
)
return
result
.
RowsAffected
,
result
.
Error
return
result
.
RowsAffected
,
result
.
Error
}
}
type
accountModel
struct
{
ID
int64
`gorm:"primaryKey"`
Name
string
`gorm:"size:100;not null"`
Platform
string
`gorm:"size:50;not null"`
Type
string
`gorm:"size:20;not null"`
Credentials
datatypes
.
JSONMap
`gorm:"type:jsonb;default:'{}'"`
Extra
datatypes
.
JSONMap
`gorm:"type:jsonb;default:'{}'"`
ProxyID
*
int64
`gorm:"index"`
Concurrency
int
`gorm:"default:3;not null"`
Priority
int
`gorm:"default:50;not null"`
Status
string
`gorm:"size:20;default:active;not null"`
ErrorMessage
string
`gorm:"type:text"`
LastUsedAt
*
time
.
Time
`gorm:"index"`
CreatedAt
time
.
Time
`gorm:"not null"`
UpdatedAt
time
.
Time
`gorm:"not null"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index"`
Schedulable
bool
`gorm:"default:true;not null"`
RateLimitedAt
*
time
.
Time
`gorm:"index"`
RateLimitResetAt
*
time
.
Time
`gorm:"index"`
OverloadUntil
*
time
.
Time
`gorm:"index"`
SessionWindowStart
*
time
.
Time
SessionWindowEnd
*
time
.
Time
SessionWindowStatus
string
`gorm:"size:20"`
Proxy
*
proxyModel
`gorm:"foreignKey:ProxyID"`
AccountGroups
[]
accountGroupModel
`gorm:"foreignKey:AccountID"`
}
func
(
accountModel
)
TableName
()
string
{
return
"accounts"
}
type
accountGroupModel
struct
{
AccountID
int64
`gorm:"primaryKey"`
GroupID
int64
`gorm:"primaryKey"`
Priority
int
`gorm:"default:50;not null"`
CreatedAt
time
.
Time
`gorm:"not null"`
Account
*
accountModel
`gorm:"foreignKey:AccountID"`
Group
*
groupModel
`gorm:"foreignKey:GroupID"`
}
func
(
accountGroupModel
)
TableName
()
string
{
return
"account_groups"
}
func
accountGroupModelToService
(
m
*
accountGroupModel
)
*
service
.
AccountGroup
{
if
m
==
nil
{
return
nil
}
return
&
service
.
AccountGroup
{
AccountID
:
m
.
AccountID
,
GroupID
:
m
.
GroupID
,
Priority
:
m
.
Priority
,
CreatedAt
:
m
.
CreatedAt
,
Account
:
accountModelToService
(
m
.
Account
),
Group
:
groupModelToService
(
m
.
Group
),
}
}
func
accountModelToService
(
m
*
accountModel
)
*
service
.
Account
{
if
m
==
nil
{
return
nil
}
var
credentials
map
[
string
]
any
if
m
.
Credentials
!=
nil
{
credentials
=
map
[
string
]
any
(
m
.
Credentials
)
}
var
extra
map
[
string
]
any
if
m
.
Extra
!=
nil
{
extra
=
map
[
string
]
any
(
m
.
Extra
)
}
account
:=
&
service
.
Account
{
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Platform
:
m
.
Platform
,
Type
:
m
.
Type
,
Credentials
:
credentials
,
Extra
:
extra
,
ProxyID
:
m
.
ProxyID
,
Concurrency
:
m
.
Concurrency
,
Priority
:
m
.
Priority
,
Status
:
m
.
Status
,
ErrorMessage
:
m
.
ErrorMessage
,
LastUsedAt
:
m
.
LastUsedAt
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
Schedulable
:
m
.
Schedulable
,
RateLimitedAt
:
m
.
RateLimitedAt
,
RateLimitResetAt
:
m
.
RateLimitResetAt
,
OverloadUntil
:
m
.
OverloadUntil
,
SessionWindowStart
:
m
.
SessionWindowStart
,
SessionWindowEnd
:
m
.
SessionWindowEnd
,
SessionWindowStatus
:
m
.
SessionWindowStatus
,
Proxy
:
proxyModelToService
(
m
.
Proxy
),
}
if
len
(
m
.
AccountGroups
)
>
0
{
account
.
AccountGroups
=
make
([]
service
.
AccountGroup
,
0
,
len
(
m
.
AccountGroups
))
account
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
m
.
AccountGroups
))
account
.
Groups
=
make
([]
*
service
.
Group
,
0
,
len
(
m
.
AccountGroups
))
for
i
:=
range
m
.
AccountGroups
{
ag
:=
accountGroupModelToService
(
&
m
.
AccountGroups
[
i
])
if
ag
==
nil
{
continue
}
account
.
AccountGroups
=
append
(
account
.
AccountGroups
,
*
ag
)
account
.
GroupIDs
=
append
(
account
.
GroupIDs
,
ag
.
GroupID
)
if
ag
.
Group
!=
nil
{
account
.
Groups
=
append
(
account
.
Groups
,
ag
.
Group
)
}
}
}
return
account
}
func
accountModelFromService
(
a
*
service
.
Account
)
*
accountModel
{
if
a
==
nil
{
return
nil
}
var
credentials
datatypes
.
JSONMap
if
a
.
Credentials
!=
nil
{
credentials
=
datatypes
.
JSONMap
(
a
.
Credentials
)
}
var
extra
datatypes
.
JSONMap
if
a
.
Extra
!=
nil
{
extra
=
datatypes
.
JSONMap
(
a
.
Extra
)
}
return
&
accountModel
{
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Platform
:
a
.
Platform
,
Type
:
a
.
Type
,
Credentials
:
credentials
,
Extra
:
extra
,
ProxyID
:
a
.
ProxyID
,
Concurrency
:
a
.
Concurrency
,
Priority
:
a
.
Priority
,
Status
:
a
.
Status
,
ErrorMessage
:
a
.
ErrorMessage
,
LastUsedAt
:
a
.
LastUsedAt
,
CreatedAt
:
a
.
CreatedAt
,
UpdatedAt
:
a
.
UpdatedAt
,
Schedulable
:
a
.
Schedulable
,
RateLimitedAt
:
a
.
RateLimitedAt
,
RateLimitResetAt
:
a
.
RateLimitResetAt
,
OverloadUntil
:
a
.
OverloadUntil
,
SessionWindowStart
:
a
.
SessionWindowStart
,
SessionWindowEnd
:
a
.
SessionWindowEnd
,
SessionWindowStatus
:
a
.
SessionWindowStatus
,
}
}
func
applyAccountModelToService
(
account
*
service
.
Account
,
m
*
accountModel
)
{
if
account
==
nil
||
m
==
nil
{
return
}
account
.
ID
=
m
.
ID
account
.
CreatedAt
=
m
.
CreatedAt
account
.
UpdatedAt
=
m
.
UpdatedAt
}
backend/internal/repository/account_repo_integration_test.go
View file @
22f07a7b
...
@@ -7,10 +7,10 @@ import (
...
@@ -7,10 +7,10 @@ import (
"testing"
"testing"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm"
)
)
...
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
...
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
// --- Create / GetByID / Update / Delete ---
func
(
s
*
AccountRepoSuite
)
TestCreate
()
{
func
(
s
*
AccountRepoSuite
)
TestCreate
()
{
account
:=
&
model
.
Account
{
account
:=
&
service
.
Account
{
Name
:
"test-create"
,
Name
:
"test-create"
,
Platform
:
model
.
PlatformAnthropic
,
Platform
:
service
.
PlatformAnthropic
,
Type
:
model
.
AccountTypeOAuth
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
Credentials
:
map
[
string
]
any
{},
Extra
:
map
[
string
]
any
{},
Concurrency
:
3
,
Priority
:
50
,
Schedulable
:
true
,
}
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
account
)
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
account
)
...
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
...
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
}
}
func
(
s
*
AccountRepoSuite
)
TestUpdate
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdate
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"original"
})
account
:=
accountModelToService
(
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"original"
})
)
account
.
Name
=
"updated"
account
.
Name
=
"updated"
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
account
)
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
account
)
...
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
...
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
}
}
func
(
s
*
AccountRepoSuite
)
TestDelete
()
{
func
(
s
*
AccountRepoSuite
)
TestDelete
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"to-delete"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"to-delete"
})
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
account
.
ID
)
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
...
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
...
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
}
}
func
(
s
*
AccountRepoSuite
)
TestDelete_WithGroupBindings
()
{
func
(
s
*
AccountRepoSuite
)
TestDelete_WithGroupBindings
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-del"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-del"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-del"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-del"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
account
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
account
.
ID
,
group
.
ID
,
1
)
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
account
.
ID
)
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete should cascade remove bindings"
)
s
.
Require
()
.
NoError
(
err
,
"Delete should cascade remove bindings"
)
var
count
int64
var
count
int64
s
.
db
.
Model
(
&
model
.
A
ccountGroup
{})
.
Where
(
"account_id = ?"
,
account
.
ID
)
.
Count
(
&
count
)
s
.
db
.
Model
(
&
a
ccountGroup
Model
{})
.
Where
(
"account_id = ?"
,
account
.
ID
)
.
Count
(
&
count
)
s
.
Require
()
.
Zero
(
count
,
"expected bindings to be removed"
)
s
.
Require
()
.
Zero
(
count
,
"expected bindings to be removed"
)
}
}
// --- List / ListWithFilters ---
// --- List / ListWithFilters ---
func
(
s
*
AccountRepoSuite
)
TestList
()
{
func
(
s
*
AccountRepoSuite
)
TestList
()
{
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc1"
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc1"
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc2"
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc2"
})
accounts
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
accounts
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
...
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
...
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
status
string
status
string
search
string
search
string
wantCount
int
wantCount
int
validate
func
(
accounts
[]
model
.
Account
)
validate
func
(
accounts
[]
service
.
Account
)
}{
}{
{
{
name
:
"filter_by_platform"
,
name
:
"filter_by_platform"
,
setup
:
func
(
db
*
gorm
.
DB
)
{
setup
:
func
(
db
*
gorm
.
DB
)
{
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"a1"
,
Platform
:
model
.
PlatformAnthropic
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"a1"
,
Platform
:
service
.
PlatformAnthropic
})
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"a2"
,
Platform
:
model
.
PlatformOpenAI
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"a2"
,
Platform
:
service
.
PlatformOpenAI
})
},
},
platform
:
model
.
PlatformOpenAI
,
platform
:
service
.
PlatformOpenAI
,
wantCount
:
1
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
model
.
Account
)
{
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
model
.
PlatformOpenAI
,
accounts
[
0
]
.
Platform
)
s
.
Require
()
.
Equal
(
service
.
PlatformOpenAI
,
accounts
[
0
]
.
Platform
)
},
},
},
},
{
{
name
:
"filter_by_type"
,
name
:
"filter_by_type"
,
setup
:
func
(
db
*
gorm
.
DB
)
{
setup
:
func
(
db
*
gorm
.
DB
)
{
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"t1"
,
Type
:
model
.
AccountTypeOAuth
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"t1"
,
Type
:
service
.
AccountTypeOAuth
})
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"t2"
,
Type
:
model
.
AccountTypeApiKey
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"t2"
,
Type
:
service
.
AccountTypeApiKey
})
},
},
accType
:
model
.
AccountTypeApiKey
,
accType
:
service
.
AccountTypeApiKey
,
wantCount
:
1
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
model
.
Account
)
{
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
model
.
AccountTypeApiKey
,
accounts
[
0
]
.
Type
)
s
.
Require
()
.
Equal
(
service
.
AccountTypeApiKey
,
accounts
[
0
]
.
Type
)
},
},
},
},
{
{
name
:
"filter_by_status"
,
name
:
"filter_by_status"
,
setup
:
func
(
db
*
gorm
.
DB
)
{
setup
:
func
(
db
*
gorm
.
DB
)
{
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"s1"
,
Status
:
model
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"s1"
,
Status
:
service
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"s2"
,
Status
:
model
.
StatusDisabled
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"s2"
,
Status
:
service
.
StatusDisabled
})
},
},
status
:
model
.
StatusDisabled
,
status
:
service
.
StatusDisabled
,
wantCount
:
1
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
model
.
Account
)
{
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
accounts
[
0
]
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
accounts
[
0
]
.
Status
)
},
},
},
},
{
{
name
:
"filter_by_search"
,
name
:
"filter_by_search"
,
setup
:
func
(
db
*
gorm
.
DB
)
{
setup
:
func
(
db
*
gorm
.
DB
)
{
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"alpha-account"
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"alpha-account"
})
mustCreateAccount
(
s
.
T
(),
db
,
&
model
.
A
ccount
{
Name
:
"beta-account"
})
mustCreateAccount
(
s
.
T
(),
db
,
&
a
ccount
Model
{
Name
:
"beta-account"
})
},
},
search
:
"alpha"
,
search
:
"alpha"
,
wantCount
:
1
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
model
.
Account
)
{
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Contains
(
accounts
[
0
]
.
Name
,
"alpha"
)
s
.
Require
()
.
Contains
(
accounts
[
0
]
.
Name
,
"alpha"
)
},
},
},
},
...
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
...
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform ---
// --- ListByGroup / ListActive / ListByPlatform ---
func
(
s
*
AccountRepoSuite
)
TestListByGroup
()
{
func
(
s
*
AccountRepoSuite
)
TestListByGroup
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-list"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-list"
})
acc1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
,
Status
:
model
.
StatusActive
})
acc1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
,
Status
:
service
.
StatusActive
})
acc2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
,
Status
:
model
.
StatusActive
})
acc2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
,
Status
:
service
.
StatusActive
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
acc1
.
ID
,
group
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
acc1
.
ID
,
group
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
acc2
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
acc2
.
ID
,
group
.
ID
,
1
)
...
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
...
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
}
}
func
(
s
*
AccountRepoSuite
)
TestListActive
()
{
func
(
s
*
AccountRepoSuite
)
TestListActive
()
{
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"active1"
,
Status
:
model
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"active1"
,
Status
:
service
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"inactive1"
,
Status
:
model
.
StatusDisabled
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"inactive1"
,
Status
:
service
.
StatusDisabled
})
accounts
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
accounts
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
...
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
...
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
}
}
func
(
s
*
AccountRepoSuite
)
TestListByPlatform
()
{
func
(
s
*
AccountRepoSuite
)
TestListByPlatform
()
{
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"p1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"p1"
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"p2"
,
Platform
:
model
.
PlatformOpenAI
,
Status
:
model
.
StatusActive
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"p2"
,
Platform
:
service
.
PlatformOpenAI
,
Status
:
service
.
StatusActive
})
accounts
,
err
:=
s
.
repo
.
ListByPlatform
(
s
.
ctx
,
model
.
PlatformAnthropic
)
accounts
,
err
:=
s
.
repo
.
ListByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
,
"ListByPlatform"
)
s
.
Require
()
.
NoError
(
err
,
"ListByPlatform"
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Equal
(
model
.
PlatformAnthropic
,
accounts
[
0
]
.
Platform
)
s
.
Require
()
.
Equal
(
service
.
PlatformAnthropic
,
accounts
[
0
]
.
Platform
)
}
}
// --- Preload and VirtualFields ---
// --- Preload and VirtualFields ---
func
(
s
*
AccountRepoSuite
)
TestPreload_And_VirtualFields
()
{
func
(
s
*
AccountRepoSuite
)
TestPreload_And_VirtualFields
()
{
proxy
:=
mustCreateProxy
(
s
.
T
(),
s
.
db
,
&
model
.
Proxy
{
Name
:
"p1"
})
proxy
:=
mustCreateProxy
(
s
.
T
(),
s
.
db
,
&
proxyModel
{
Name
:
"p1"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc1"
,
Name
:
"acc1"
,
ProxyID
:
&
proxy
.
ID
,
ProxyID
:
&
proxy
.
ID
,
})
})
...
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
...
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func
(
s
*
AccountRepoSuite
)
TestGroupBinding_And_BindGroups
()
{
func
(
s
*
AccountRepoSuite
)
TestGroupBinding_And_BindGroups
()
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
})
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
})
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
})
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc"
})
s
.
Require
()
.
NoError
(
s
.
repo
.
AddToGroup
(
s
.
ctx
,
account
.
ID
,
g1
.
ID
,
10
),
"AddToGroup"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
AddToGroup
(
s
.
ctx
,
account
.
ID
,
g1
.
ID
,
10
),
"AddToGroup"
)
groups
,
err
:=
s
.
repo
.
GetGroups
(
s
.
ctx
,
account
.
ID
)
groups
,
err
:=
s
.
repo
.
GetGroups
(
s
.
ctx
,
account
.
ID
)
...
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
...
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
}
}
func
(
s
*
AccountRepoSuite
)
TestBindGroups_EmptyList
()
{
func
(
s
*
AccountRepoSuite
)
TestBindGroups_EmptyList
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-empty"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-empty"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-empty"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-empty"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
account
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
account
.
ID
,
group
.
ID
,
1
)
s
.
Require
()
.
NoError
(
s
.
repo
.
BindGroups
(
s
.
ctx
,
account
.
ID
,
[]
int64
{}),
"BindGroups empty"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
BindGroups
(
s
.
ctx
,
account
.
ID
,
[]
int64
{}),
"BindGroups empty"
)
...
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
...
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func
(
s
*
AccountRepoSuite
)
TestListSchedulable
()
{
func
(
s
*
AccountRepoSuite
)
TestListSchedulable
()
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-sched"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-sched"
})
okAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"ok"
,
Schedulable
:
true
})
okAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"ok"
,
Schedulable
:
true
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
okAcc
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
okAcc
.
ID
,
group
.
ID
,
1
)
future
:=
now
.
Add
(
10
*
time
.
Minute
)
future
:=
now
.
Add
(
10
*
time
.
Minute
)
overloaded
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"over"
,
Schedulable
:
true
,
OverloadUntil
:
&
future
})
overloaded
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"over"
,
Schedulable
:
true
,
OverloadUntil
:
&
future
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
overloaded
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
overloaded
.
ID
,
group
.
ID
,
1
)
sched
,
err
:=
s
.
repo
.
ListSchedulable
(
s
.
ctx
)
sched
,
err
:=
s
.
repo
.
ListSchedulable
(
s
.
ctx
)
...
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
...
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates
()
{
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates
()
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-sched"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-sched"
})
okAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"ok"
,
Schedulable
:
true
})
okAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"ok"
,
Schedulable
:
true
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
okAcc
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
okAcc
.
ID
,
group
.
ID
,
1
)
future
:=
now
.
Add
(
10
*
time
.
Minute
)
future
:=
now
.
Add
(
10
*
time
.
Minute
)
overloaded
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"over"
,
Schedulable
:
true
,
OverloadUntil
:
&
future
})
overloaded
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"over"
,
Schedulable
:
true
,
OverloadUntil
:
&
future
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
overloaded
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
overloaded
.
ID
,
group
.
ID
,
1
)
rateLimited
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"rl"
,
Schedulable
:
true
})
rateLimited
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"rl"
,
Schedulable
:
true
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
rateLimited
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
rateLimited
.
ID
,
group
.
ID
,
1
)
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
rateLimited
.
ID
,
now
.
Add
(
10
*
time
.
Minute
)),
"SetRateLimited"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
rateLimited
.
ID
,
now
.
Add
(
10
*
time
.
Minute
)),
"SetRateLimited"
)
...
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
...
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
}
}
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByPlatform
()
{
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByPlatform
()
{
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
,
Platform
:
model
.
PlatformAnthropic
,
Schedulable
:
true
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
,
Platform
:
service
.
PlatformAnthropic
,
Schedulable
:
true
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
,
Platform
:
model
.
PlatformOpenAI
,
Schedulable
:
true
})
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
,
Platform
:
service
.
PlatformOpenAI
,
Schedulable
:
true
})
accounts
,
err
:=
s
.
repo
.
ListSchedulableByPlatform
(
s
.
ctx
,
model
.
PlatformAnthropic
)
accounts
,
err
:=
s
.
repo
.
ListSchedulableByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Equal
(
model
.
PlatformAnthropic
,
accounts
[
0
]
.
Platform
)
s
.
Require
()
.
Equal
(
service
.
PlatformAnthropic
,
accounts
[
0
]
.
Platform
)
}
}
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByGroupIDAndPlatform
()
{
func
(
s
*
AccountRepoSuite
)
TestListSchedulableByGroupIDAndPlatform
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-sp"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-sp"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
,
Platform
:
model
.
PlatformAnthropic
,
Schedulable
:
true
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
,
Platform
:
service
.
PlatformAnthropic
,
Schedulable
:
true
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
,
Platform
:
model
.
PlatformOpenAI
,
Schedulable
:
true
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
,
Platform
:
service
.
PlatformOpenAI
,
Schedulable
:
true
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
accounts
,
err
:=
s
.
repo
.
ListSchedulableByGroupIDAndPlatform
(
s
.
ctx
,
group
.
ID
,
model
.
PlatformAnthropic
)
accounts
,
err
:=
s
.
repo
.
ListSchedulableByGroupIDAndPlatform
(
s
.
ctx
,
group
.
ID
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Equal
(
a1
.
ID
,
accounts
[
0
]
.
ID
)
s
.
Require
()
.
Equal
(
a1
.
ID
,
accounts
[
0
]
.
ID
)
}
}
func
(
s
*
AccountRepoSuite
)
TestSetSchedulable
()
{
func
(
s
*
AccountRepoSuite
)
TestSetSchedulable
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-sched"
,
Schedulable
:
true
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-sched"
,
Schedulable
:
true
})
s
.
Require
()
.
NoError
(
s
.
repo
.
SetSchedulable
(
s
.
ctx
,
account
.
ID
,
false
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetSchedulable
(
s
.
ctx
,
account
.
ID
,
false
))
...
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
...
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func
(
s
*
AccountRepoSuite
)
TestSetOverloaded
()
{
func
(
s
*
AccountRepoSuite
)
TestSetOverloaded
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-over"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-over"
})
until
:=
time
.
Date
(
2025
,
6
,
15
,
12
,
0
,
0
,
0
,
time
.
UTC
)
until
:=
time
.
Date
(
2025
,
6
,
15
,
12
,
0
,
0
,
0
,
time
.
UTC
)
s
.
Require
()
.
NoError
(
s
.
repo
.
SetOverloaded
(
s
.
ctx
,
account
.
ID
,
until
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetOverloaded
(
s
.
ctx
,
account
.
ID
,
until
))
...
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
...
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
}
}
func
(
s
*
AccountRepoSuite
)
TestSetRateLimited
()
{
func
(
s
*
AccountRepoSuite
)
TestSetRateLimited
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-rl"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-rl"
})
resetAt
:=
time
.
Date
(
2025
,
6
,
15
,
14
,
0
,
0
,
0
,
time
.
UTC
)
resetAt
:=
time
.
Date
(
2025
,
6
,
15
,
14
,
0
,
0
,
0
,
time
.
UTC
)
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
account
.
ID
,
resetAt
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
account
.
ID
,
resetAt
))
...
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
...
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
}
}
func
(
s
*
AccountRepoSuite
)
TestClearRateLimit
()
{
func
(
s
*
AccountRepoSuite
)
TestClearRateLimit
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-clear"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-clear"
})
until
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
until
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
s
.
Require
()
.
NoError
(
s
.
repo
.
SetOverloaded
(
s
.
ctx
,
account
.
ID
,
until
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetOverloaded
(
s
.
ctx
,
account
.
ID
,
until
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
account
.
ID
,
until
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetRateLimited
(
s
.
ctx
,
account
.
ID
,
until
))
...
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
...
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed ---
// --- UpdateLastUsed ---
func
(
s
*
AccountRepoSuite
)
TestUpdateLastUsed
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateLastUsed
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-used"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-used"
})
s
.
Require
()
.
Nil
(
account
.
LastUsedAt
)
s
.
Require
()
.
Nil
(
account
.
LastUsedAt
)
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateLastUsed
(
s
.
ctx
,
account
.
ID
))
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateLastUsed
(
s
.
ctx
,
account
.
ID
))
...
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
...
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
// --- SetError ---
func
(
s
*
AccountRepoSuite
)
TestSetError
()
{
func
(
s
*
AccountRepoSuite
)
TestSetError
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-err"
,
Status
:
model
.
StatusActive
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-err"
,
Status
:
service
.
StatusActive
})
s
.
Require
()
.
NoError
(
s
.
repo
.
SetError
(
s
.
ctx
,
account
.
ID
,
"something went wrong"
))
s
.
Require
()
.
NoError
(
s
.
repo
.
SetError
(
s
.
ctx
,
account
.
ID
,
"something went wrong"
))
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
account
.
ID
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
account
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
model
.
StatusError
,
got
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusError
,
got
.
Status
)
s
.
Require
()
.
Equal
(
"something went wrong"
,
got
.
ErrorMessage
)
s
.
Require
()
.
Equal
(
"something went wrong"
,
got
.
ErrorMessage
)
}
}
// --- UpdateSessionWindow ---
// --- UpdateSessionWindow ---
func
(
s
*
AccountRepoSuite
)
TestUpdateSessionWindow
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateSessionWindow
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-win"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-win"
})
start
:=
time
.
Date
(
2025
,
6
,
15
,
10
,
0
,
0
,
0
,
time
.
UTC
)
start
:=
time
.
Date
(
2025
,
6
,
15
,
10
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
time
.
Date
(
2025
,
6
,
15
,
15
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
time
.
Date
(
2025
,
6
,
15
,
15
,
0
,
0
,
0
,
time
.
UTC
)
...
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
...
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra ---
// --- UpdateExtra ---
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_MergesFields
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_MergesFields
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-extra"
,
Name
:
"acc-extra"
,
Extra
:
model
.
JSON
B
{
"a"
:
"1"
},
Extra
:
datatypes
.
JSON
Map
{
"a"
:
"1"
},
})
})
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{
"b"
:
"2"
}),
"UpdateExtra"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{
"b"
:
"2"
}),
"UpdateExtra"
)
...
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
...
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
}
}
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_EmptyUpdates
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_EmptyUpdates
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-extra-empty"
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-extra-empty"
})
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{}))
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{}))
}
}
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_NilExtra
()
{
func
(
s
*
AccountRepoSuite
)
TestUpdateExtra_NilExtra
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-nil-extra"
,
Extra
:
nil
})
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-nil-extra"
,
Extra
:
nil
})
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{
"key"
:
"val"
}))
s
.
Require
()
.
NoError
(
s
.
repo
.
UpdateExtra
(
s
.
ctx
,
account
.
ID
,
map
[
string
]
any
{
"key"
:
"val"
}))
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
account
.
ID
)
got
,
err
:=
s
.
repo
.
GetByID
(
s
.
ctx
,
account
.
ID
)
...
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
...
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func
(
s
*
AccountRepoSuite
)
TestGetByCRSAccountID
()
{
func
(
s
*
AccountRepoSuite
)
TestGetByCRSAccountID
()
{
crsID
:=
"crs-12345"
crsID
:=
"crs-12345"
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-crs"
,
Name
:
"acc-crs"
,
Extra
:
model
.
JSON
B
{
"crs_account_id"
:
crsID
},
Extra
:
datatypes
.
JSON
Map
{
"crs_account_id"
:
crsID
},
})
})
got
,
err
:=
s
.
repo
.
GetByCRSAccountID
(
s
.
ctx
,
crsID
)
got
,
err
:=
s
.
repo
.
GetByCRSAccountID
(
s
.
ctx
,
crsID
)
...
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
...
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate ---
// --- BulkUpdate ---
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate
()
{
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate
()
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"bulk1"
,
Priority
:
1
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"bulk1"
,
Priority
:
1
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"bulk2"
,
Priority
:
1
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"bulk2"
,
Priority
:
1
})
newPriority
:=
99
newPriority
:=
99
affected
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
,
a2
.
ID
},
service
.
AccountBulkUpdate
{
affected
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
,
a2
.
ID
},
service
.
AccountBulkUpdate
{
...
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
...
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
}
}
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_MergeCredentials
()
{
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_MergeCredentials
()
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"bulk-cred"
,
Name
:
"bulk-cred"
,
Credentials
:
model
.
JSON
B
{
"existing"
:
"value"
},
Credentials
:
datatypes
.
JSON
Map
{
"existing"
:
"value"
},
})
})
_
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{
_
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{
Credentials
:
model
.
JSON
B
{
"new_key"
:
"new_value"
},
Credentials
:
datatypes
.
JSON
Map
{
"new_key"
:
"new_value"
},
})
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
...
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
...
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
}
}
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_MergeExtra
()
{
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_MergeExtra
()
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"bulk-extra"
,
Name
:
"bulk-extra"
,
Extra
:
model
.
JSON
B
{
"existing"
:
"val"
},
Extra
:
datatypes
.
JSON
Map
{
"existing"
:
"val"
},
})
})
_
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{
_
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{
Extra
:
model
.
JSON
B
{
"new_key"
:
"new_val"
},
Extra
:
datatypes
.
JSON
Map
{
"new_key"
:
"new_val"
},
})
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
...
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
...
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
}
}
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_EmptyUpdates
()
{
func
(
s
*
AccountRepoSuite
)
TestBulkUpdate_EmptyUpdates
()
{
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"bulk-empty"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"bulk-empty"
})
affected
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{})
affected
,
err
:=
s
.
repo
.
BulkUpdate
(
s
.
ctx
,
[]
int64
{
a1
.
ID
},
service
.
AccountBulkUpdate
{})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Zero
(
affected
)
s
.
Require
()
.
Zero
(
affected
)
}
}
func
idsOfAccounts
(
accounts
[]
model
.
Account
)
[]
int64
{
func
idsOfAccounts
(
accounts
[]
service
.
Account
)
[]
int64
{
out
:=
make
([]
int64
,
0
,
len
(
accounts
))
out
:=
make
([]
int64
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
out
=
append
(
out
,
accounts
[
i
]
.
ID
)
out
=
append
(
out
,
accounts
[
i
]
.
ID
)
...
...
backend/internal/repository/api_key_repo.go
View file @
22f07a7b
...
@@ -2,10 +2,10 @@ package repository
...
@@ -2,10 +2,10 @@ package repository
import
(
import
(
"context"
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
"gorm.io/gorm"
...
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
...
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return
&
apiKeyRepository
{
db
:
db
}
return
&
apiKeyRepository
{
db
:
db
}
}
}
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
key
)
.
Error
m
:=
apiKeyModelFromService
(
key
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
m
)
.
Error
if
err
==
nil
{
applyApiKeyModelToService
(
key
,
m
)
}
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrApiKeyExists
)
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrApiKeyExists
)
}
}
func
(
r
*
apiKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
func
(
r
*
apiKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ApiKey
,
error
)
{
var
key
model
.
ApiKey
var
m
apiKeyModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
key
,
id
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
m
,
id
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
}
return
&
key
,
nil
return
apiKeyModelToService
(
&
m
)
,
nil
}
}
func
(
r
*
apiKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
func
(
r
*
apiKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
ApiKey
,
error
)
{
var
apiKey
m
odel
.
ApiKey
var
m
apiKey
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
apiKey
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
m
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
}
return
&
apiKey
,
nil
return
apiKey
ModelToService
(
&
m
)
,
nil
}
}
func
(
r
*
apiKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
func
(
r
*
apiKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
key
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
key
)
.
Error
m
:=
apiKeyModelFromService
(
key
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
m
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
m
)
.
Error
if
err
==
nil
{
applyApiKeyModelToService
(
key
,
m
)
}
return
err
}
}
func
(
r
*
apiKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
apiKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
ApiKey
{},
id
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
apiKeyModel
{},
id
)
.
Error
}
}
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
apiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
keys
[]
apiKeyModel
var
total
int64
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"user_id = ?"
,
userID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
...
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
...
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
pages
:=
int
(
total
)
/
params
.
Limit
(
)
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
)
)
if
int
(
total
)
%
params
.
Limit
()
>
0
{
for
i
:=
range
keys
{
pages
++
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
}
return
keys
,
&
pagination
.
PaginationResult
{
return
outKeys
,
paginationResultFromTotal
(
total
,
params
),
nil
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
}
func
(
r
*
apiKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
func
(
r
*
apiKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
var
count
int64
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
return
count
,
err
}
}
func
(
r
*
apiKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
func
(
r
*
apiKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
var
count
int64
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
return
count
>
0
,
err
}
}
func
(
r
*
apiKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
apiKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
keys
[]
apiKeyModel
var
total
int64
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
...
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
...
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
pages
:=
int
(
total
)
/
params
.
Limit
(
)
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
)
)
if
int
(
total
)
%
params
.
Limit
()
>
0
{
for
i
:=
range
keys
{
pages
++
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
}
return
keys
,
&
pagination
.
PaginationResult
{
return
outKeys
,
paginationResultFromTotal
(
total
,
params
),
nil
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func
(
r
*
apiKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
model
.
ApiKey
,
error
)
{
func
(
r
*
apiKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
service
.
ApiKey
,
error
)
{
var
keys
[]
model
.
ApiKey
var
keys
[]
apiKeyModel
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
if
userID
>
0
{
if
userID
>
0
{
db
=
db
.
Where
(
"user_id = ?"
,
userID
)
db
=
db
.
Where
(
"user_id = ?"
,
userID
)
...
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
...
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return
nil
,
err
return
nil
,
err
}
}
return
keys
,
nil
outKeys
:=
make
([]
service
.
ApiKey
,
0
,
len
(
keys
))
for
i
:=
range
keys
{
outKeys
=
append
(
outKeys
,
*
apiKeyModelToService
(
&
keys
[
i
]))
}
return
outKeys
,
nil
}
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func
(
r
*
apiKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
apiKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Update
(
"group_id"
,
nil
)
Update
(
"group_id"
,
nil
)
return
result
.
RowsAffected
,
result
.
Error
return
result
.
RowsAffected
,
result
.
Error
...
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
...
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量
// CountByGroupID 获取分组的 API Key 数量
func
(
r
*
apiKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
apiKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
apiKeyModel
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
return
count
,
err
}
}
type
apiKeyModel
struct
{
ID
int64
`gorm:"primaryKey"`
UserID
int64
`gorm:"index;not null"`
Key
string
`gorm:"uniqueIndex;size:128;not null"`
Name
string
`gorm:"size:100;not null"`
GroupID
*
int64
`gorm:"index"`
Status
string
`gorm:"size:20;default:active;not null"`
CreatedAt
time
.
Time
`gorm:"not null"`
UpdatedAt
time
.
Time
`gorm:"not null"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index"`
User
*
userModel
`gorm:"foreignKey:UserID"`
Group
*
groupModel
`gorm:"foreignKey:GroupID"`
}
func
(
apiKeyModel
)
TableName
()
string
{
return
"api_keys"
}
func
apiKeyModelToService
(
m
*
apiKeyModel
)
*
service
.
ApiKey
{
if
m
==
nil
{
return
nil
}
return
&
service
.
ApiKey
{
ID
:
m
.
ID
,
UserID
:
m
.
UserID
,
Key
:
m
.
Key
,
Name
:
m
.
Name
,
GroupID
:
m
.
GroupID
,
Status
:
m
.
Status
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
User
:
userModelToService
(
m
.
User
),
Group
:
groupModelToService
(
m
.
Group
),
}
}
func
apiKeyModelFromService
(
k
*
service
.
ApiKey
)
*
apiKeyModel
{
if
k
==
nil
{
return
nil
}
return
&
apiKeyModel
{
ID
:
k
.
ID
,
UserID
:
k
.
UserID
,
Key
:
k
.
Key
,
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
}
}
func
applyApiKeyModelToService
(
key
*
service
.
ApiKey
,
m
*
apiKeyModel
)
{
if
key
==
nil
||
m
==
nil
{
return
}
key
.
ID
=
m
.
ID
key
.
CreatedAt
=
m
.
CreatedAt
key
.
UpdatedAt
=
m
.
UpdatedAt
}
backend/internal/repository/api_key_repo_integration_test.go
View file @
22f07a7b
...
@@ -6,8 +6,8 @@ import (
...
@@ -6,8 +6,8 @@ import (
"context"
"context"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
"gorm.io/gorm"
)
)
...
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
...
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
// --- Create / GetByID / GetByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestCreate
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestCreate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"create@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"create@test.com"
})
key
:=
&
model
.
ApiKey
{
key
:=
&
service
.
ApiKey
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-create-test"
,
Key
:
"sk-create-test"
,
Name
:
"Test Key"
,
Name
:
"Test Key"
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
}
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
key
)
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
key
)
...
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
...
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestGetByKey
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestGetByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"getbykey@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"getbykey@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-key"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-key"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-getbykey"
,
Key
:
"sk-getbykey"
,
Name
:
"My Key"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
})
})
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
...
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
...
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
// --- Update ---
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"update@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"update@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-update"
,
Key
:
"sk-update"
,
Name
:
"Original"
,
Name
:
"Original"
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
})
})
)
key
.
Name
=
"Renamed"
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
key
.
Status
=
service
.
StatusDisabled
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
...
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
...
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s
.
Require
()
.
Equal
(
"sk-update"
,
got
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
"sk-update"
,
got
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got
.
Name
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
got
.
Status
)
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate_ClearGroupID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestUpdate_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargroup@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"cleargroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-clear"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-clear-group"
,
Key
:
"sk-clear-group"
,
Name
:
"Group Key"
,
Name
:
"Group Key"
,
GroupID
:
&
group
.
ID
,
GroupID
:
&
group
.
ID
,
})
})
)
key
.
GroupID
=
nil
key
.
GroupID
=
nil
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
key
)
...
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
...
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
// --- Delete ---
func
(
s
*
ApiKeyRepoSuite
)
TestDelete
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestDelete
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"delete@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"delete@test.com"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-delete"
,
Key
:
"sk-delete"
,
Name
:
"Delete Me"
,
Name
:
"Delete Me"
,
...
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
...
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
// --- ListByUserID / CountByUserID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbyuser@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"listbyuser@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-1"
,
Name
:
"Key 1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-list-1"
,
Name
:
"Key 1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-list-2"
,
Name
:
"Key 2"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-list-2"
,
Name
:
"Key 2"
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
s
.
Require
()
.
NoError
(
err
,
"ListByUserID"
)
...
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
...
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID_Pagination
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestListByUserID_Pagination
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"paging@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"paging@test.com"
})
for
i
:=
0
;
i
<
5
;
i
++
{
for
i
:=
0
;
i
<
5
;
i
++
{
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-page-"
+
string
(
rune
(
'a'
+
i
)),
Key
:
"sk-page-"
+
string
(
rune
(
'a'
+
i
)),
Name
:
"Key"
,
Name
:
"Key"
,
...
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
...
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByUserID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestCountByUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"count@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"count@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-count-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-count-2"
,
Name
:
"K2"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-count-2"
,
Name
:
"K2"
})
count
,
err
:=
s
.
repo
.
CountByUserID
(
s
.
ctx
,
user
.
ID
)
count
,
err
:=
s
.
repo
.
CountByUserID
(
s
.
ctx
,
user
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByUserID"
)
s
.
Require
()
.
NoError
(
err
,
"CountByUserID"
)
...
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
...
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
// --- ListByGroupID / CountByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestListByGroupID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestListByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"listbygroup@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"listbygroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-list"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-list"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-3"
,
Name
:
"K3"
})
// no group
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-grp-3"
,
Name
:
"K3"
})
// no group
keys
,
page
,
err
:=
s
.
repo
.
ListByGroupID
(
s
.
ctx
,
group
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
keys
,
page
,
err
:=
s
.
repo
.
ListByGroupID
(
s
.
ctx
,
group
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"ListByGroupID"
)
s
.
Require
()
.
NoError
(
err
,
"ListByGroupID"
)
...
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
...
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestCountByGroupID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestCountByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"countgroup@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"countgroup@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-count"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-gc-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-gc-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
count
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
count
,
err
:=
s
.
repo
.
CountByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID"
)
s
.
Require
()
.
NoError
(
err
,
"CountByGroupID"
)
...
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
...
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
// --- ExistsByKey ---
func
(
s
*
ApiKeyRepoSuite
)
TestExistsByKey
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestExistsByKey
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"exists@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"exists@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-exists"
,
Name
:
"K"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-exists"
,
Name
:
"K"
})
exists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-exists"
)
exists
,
err
:=
s
.
repo
.
ExistsByKey
(
s
.
ctx
,
"sk-exists"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByKey"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByKey"
)
...
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
...
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
// --- SearchApiKeys ---
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"search@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"search@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-1"
,
Name
:
"Production Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-search-1"
,
Name
:
"Production Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-search-2"
,
Name
:
"Development Key"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-search-2"
,
Name
:
"Development Key"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
"prod"
,
10
)
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
"prod"
,
10
)
s
.
Require
()
.
NoError
(
err
,
"SearchApiKeys"
)
s
.
Require
()
.
NoError
(
err
,
"SearchApiKeys"
)
...
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
...
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoKeyword
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoKeyword
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnokw@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"searchnokw@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-1"
,
Name
:
"K1"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-2"
,
Name
:
"K2"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nk-2"
,
Name
:
"K2"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
""
,
10
)
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
user
.
ID
,
""
,
10
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
...
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
...
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
}
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoUserID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestSearchApiKeys_NoUserID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"searchnouid@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"searchnouid@test.com"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-nu-1"
,
Name
:
"TestKey"
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-nu-1"
,
Name
:
"TestKey"
})
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
0
,
"testkey"
,
10
)
found
,
err
:=
s
.
repo
.
SearchApiKeys
(
s
.
ctx
,
0
,
"testkey"
,
10
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
...
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
...
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
// --- ClearGroupIDByGroupID ---
func
(
s
*
ApiKeyRepoSuite
)
TestClearGroupIDByGroupID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestClearGroupIDByGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"cleargrp@test.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"cleargrp@test.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-clear-bulk"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-clear-bulk"
})
k1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
k1
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-1"
,
Name
:
"K1"
,
GroupID
:
&
group
.
ID
})
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-2"
,
Name
:
"K2"
,
GroupID
:
&
group
.
ID
})
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-3"
,
Name
:
"K3"
})
// no group
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
Key
:
"sk-clr-3"
,
Name
:
"K3"
})
// no group
affected
,
err
:=
s
.
repo
.
ClearGroupIDByGroupID
(
s
.
ctx
,
group
.
ID
)
affected
,
err
:=
s
.
repo
.
ClearGroupIDByGroupID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"ClearGroupIDByGroupID"
)
s
.
Require
()
.
NoError
(
err
,
"ClearGroupIDByGroupID"
)
...
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
...
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func
(
s
*
ApiKeyRepoSuite
)
TestCRUD_Search_ClearGroupID
()
{
func
(
s
*
ApiKeyRepoSuite
)
TestCRUD_Search_ClearGroupID
()
{
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
model
.
User
{
Email
:
"k@example.com"
})
user
:=
mustCreateUser
(
s
.
T
(),
s
.
db
,
&
userModel
{
Email
:
"k@example.com"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-k"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-k"
})
key
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
key
:=
apiKeyModelToService
(
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-test-1"
,
Key
:
"sk-test-1"
,
Name
:
"My Key"
,
Name
:
"My Key"
,
GroupID
:
&
group
.
ID
,
GroupID
:
&
group
.
ID
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
})
})
)
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
got
,
err
:=
s
.
repo
.
GetByKey
(
s
.
ctx
,
key
.
Key
)
s
.
Require
()
.
NoError
(
err
,
"GetByKey"
)
s
.
Require
()
.
NoError
(
err
,
"GetByKey"
)
...
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
...
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Group
.
ID
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
Group
.
ID
)
key
.
Name
=
"Renamed"
key
.
Name
=
"Renamed"
key
.
Status
=
model
.
StatusDisabled
key
.
Status
=
service
.
StatusDisabled
key
.
GroupID
=
nil
key
.
GroupID
=
nil
s
.
Require
()
.
NoError
(
s
.
repo
.
Update
(
s
.
ctx
,
key
),
"Update"
)
s
.
Require
()
.
NoError
(
s
.
repo
.
Update
(
s
.
ctx
,
key
),
"Update"
)
...
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
...
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
"sk-test-1"
,
got2
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
"sk-test-1"
,
got2
.
Key
,
"Update should not change key"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got2
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
user
.
ID
,
got2
.
UserID
,
"Update should not change user_id"
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got2
.
Name
)
s
.
Require
()
.
Equal
(
"Renamed"
,
got2
.
Name
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
got2
.
Status
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
s
.
Require
()
.
Nil
(
got2
.
GroupID
)
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
keys
,
page
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
...
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
...
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s
.
Require
()
.
Equal
(
key
.
ID
,
found
[
0
]
.
ID
)
s
.
Require
()
.
Equal
(
key
.
ID
,
found
[
0
]
.
ID
)
// ClearGroupIDByGroupID
// ClearGroupIDByGroupID
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
model
.
ApiKey
{
k2
:=
mustCreateApiKey
(
s
.
T
(),
s
.
db
,
&
apiKeyModel
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Key
:
"sk-test-2"
,
Key
:
"sk-test-2"
,
Name
:
"Group Key"
,
Name
:
"Group Key"
,
...
...
backend/internal/repository/auto_migrate.go
0 → 100644
View file @
22f07a7b
package
repository
import
"gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func
AutoMigrate
(
db
*
gorm
.
DB
)
error
{
return
db
.
AutoMigrate
(
&
userModel
{},
&
apiKeyModel
{},
&
groupModel
{},
&
accountModel
{},
&
accountGroupModel
{},
&
proxyModel
{},
&
redeemCodeModel
{},
&
usageLogModel
{},
&
settingModel
{},
&
userSubscriptionModel
{},
)
}
backend/internal/repository/fixtures_integration_test.go
View file @
22f07a7b
...
@@ -6,21 +6,25 @@ import (
...
@@ -6,21 +6,25 @@ import (
"testing"
"testing"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/
model
"
"github.com/Wei-Shaw/sub2api/internal/
service
"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm"
)
)
func
mustCreateUser
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
u
*
model
.
User
)
*
model
.
User
{
func
mustCreateUser
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
u
*
userModel
)
*
userModel
{
t
.
Helper
()
t
.
Helper
()
if
u
.
PasswordHash
==
""
{
if
u
.
PasswordHash
==
""
{
u
.
PasswordHash
=
"test-password-hash"
u
.
PasswordHash
=
"test-password-hash"
}
}
if
u
.
Role
==
""
{
if
u
.
Role
==
""
{
u
.
Role
=
model
.
RoleUser
u
.
Role
=
service
.
RoleUser
}
}
if
u
.
Status
==
""
{
if
u
.
Status
==
""
{
u
.
Status
=
model
.
StatusActive
u
.
Status
=
service
.
StatusActive
}
if
u
.
Concurrency
==
0
{
u
.
Concurrency
=
5
}
}
if
u
.
CreatedAt
.
IsZero
()
{
if
u
.
CreatedAt
.
IsZero
()
{
u
.
CreatedAt
=
time
.
Now
()
u
.
CreatedAt
=
time
.
Now
()
...
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
...
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return
u
return
u
}
}
func
mustCreateGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
g
*
model
.
Group
)
*
model
.
Group
{
func
mustCreateGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
g
*
groupModel
)
*
groupModel
{
t
.
Helper
()
t
.
Helper
()
if
g
.
Platform
==
""
{
if
g
.
Platform
==
""
{
g
.
Platform
=
model
.
PlatformAnthropic
g
.
Platform
=
service
.
PlatformAnthropic
}
}
if
g
.
Status
==
""
{
if
g
.
Status
==
""
{
g
.
Status
=
model
.
StatusActive
g
.
Status
=
service
.
StatusActive
}
}
if
g
.
SubscriptionType
==
""
{
if
g
.
SubscriptionType
==
""
{
g
.
SubscriptionType
=
model
.
SubscriptionTypeStandard
g
.
SubscriptionType
=
service
.
SubscriptionTypeStandard
}
}
if
g
.
CreatedAt
.
IsZero
()
{
if
g
.
CreatedAt
.
IsZero
()
{
g
.
CreatedAt
=
time
.
Now
()
g
.
CreatedAt
=
time
.
Now
()
...
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
...
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return
g
return
g
}
}
func
mustCreateProxy
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
p
*
model
.
Proxy
)
*
model
.
Proxy
{
func
mustCreateProxy
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
p
*
proxyModel
)
*
proxyModel
{
t
.
Helper
()
t
.
Helper
()
if
p
.
Protocol
==
""
{
if
p
.
Protocol
==
""
{
p
.
Protocol
=
"http"
p
.
Protocol
=
"http"
...
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
...
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p
.
Port
=
8080
p
.
Port
=
8080
}
}
if
p
.
Status
==
""
{
if
p
.
Status
==
""
{
p
.
Status
=
model
.
StatusActive
p
.
Status
=
service
.
StatusActive
}
}
if
p
.
CreatedAt
.
IsZero
()
{
if
p
.
CreatedAt
.
IsZero
()
{
p
.
CreatedAt
=
time
.
Now
()
p
.
CreatedAt
=
time
.
Now
()
...
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
...
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return
p
return
p
}
}
func
mustCreateAccount
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
a
*
model
.
Account
)
*
model
.
A
ccount
{
func
mustCreateAccount
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
a
*
accountModel
)
*
a
ccount
Model
{
t
.
Helper
()
t
.
Helper
()
if
a
.
Platform
==
""
{
if
a
.
Platform
==
""
{
a
.
Platform
=
model
.
PlatformAnthropic
a
.
Platform
=
service
.
PlatformAnthropic
}
}
if
a
.
Type
==
""
{
if
a
.
Type
==
""
{
a
.
Type
=
model
.
AccountTypeOAuth
a
.
Type
=
service
.
AccountTypeOAuth
}
}
if
a
.
Status
==
""
{
if
a
.
Status
==
""
{
a
.
Status
=
model
.
StatusActive
a
.
Status
=
service
.
StatusActive
}
}
if
!
a
.
Schedulable
{
if
!
a
.
Schedulable
{
a
.
Schedulable
=
true
a
.
Schedulable
=
true
}
}
if
a
.
Credentials
==
nil
{
if
a
.
Credentials
==
nil
{
a
.
Credentials
=
model
.
JSON
B
{}
a
.
Credentials
=
datatypes
.
JSON
Map
{}
}
}
if
a
.
Extra
==
nil
{
if
a
.
Extra
==
nil
{
a
.
Extra
=
model
.
JSON
B
{}
a
.
Extra
=
datatypes
.
JSON
Map
{}
}
}
if
a
.
CreatedAt
.
IsZero
()
{
if
a
.
CreatedAt
.
IsZero
()
{
a
.
CreatedAt
=
time
.
Now
()
a
.
CreatedAt
=
time
.
Now
()
...
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
...
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return
a
return
a
}
}
func
mustCreateApiKey
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
k
*
model
.
ApiKey
)
*
model
.
ApiKey
{
func
mustCreateApiKey
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
k
*
apiKeyModel
)
*
apiKeyModel
{
t
.
Helper
()
t
.
Helper
()
if
k
.
Status
==
""
{
if
k
.
Status
==
""
{
k
.
Status
=
model
.
StatusActive
k
.
Status
=
service
.
StatusActive
}
}
if
k
.
CreatedAt
.
IsZero
()
{
if
k
.
CreatedAt
.
IsZero
()
{
k
.
CreatedAt
=
time
.
Now
()
k
.
CreatedAt
=
time
.
Now
()
...
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
...
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return
k
return
k
}
}
func
mustCreateRedeemCode
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
c
*
model
.
R
edeemCode
)
*
model
.
R
edeemCode
{
func
mustCreateRedeemCode
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
c
*
r
edeemCode
Model
)
*
r
edeemCode
Model
{
t
.
Helper
()
t
.
Helper
()
if
c
.
Status
==
""
{
if
c
.
Status
==
""
{
c
.
Status
=
model
.
StatusUnused
c
.
Status
=
service
.
StatusUnused
}
}
if
c
.
Type
==
""
{
if
c
.
Type
==
""
{
c
.
Type
=
model
.
RedeemTypeBalance
c
.
Type
=
service
.
RedeemTypeBalance
}
}
if
c
.
CreatedAt
.
IsZero
()
{
if
c
.
CreatedAt
.
IsZero
()
{
c
.
CreatedAt
=
time
.
Now
()
c
.
CreatedAt
=
time
.
Now
()
...
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
...
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return
c
return
c
}
}
func
mustCreateSubscription
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
s
*
model
.
U
serSubscription
)
*
model
.
U
serSubscription
{
func
mustCreateSubscription
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
s
*
u
serSubscription
Model
)
*
u
serSubscription
Model
{
t
.
Helper
()
t
.
Helper
()
if
s
.
Status
==
""
{
if
s
.
Status
==
""
{
s
.
Status
=
model
.
SubscriptionStatusActive
s
.
Status
=
service
.
SubscriptionStatusActive
}
}
now
:=
time
.
Now
()
now
:=
time
.
Now
()
if
s
.
StartsAt
.
IsZero
()
{
if
s
.
StartsAt
.
IsZero
()
{
...
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
...
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func
mustBindAccountToGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
accountID
,
groupID
int64
,
priority
int
)
{
func
mustBindAccountToGroup
(
t
*
testing
.
T
,
db
*
gorm
.
DB
,
accountID
,
groupID
int64
,
priority
int
)
{
t
.
Helper
()
t
.
Helper
()
require
.
NoError
(
t
,
db
.
Create
(
&
model
.
A
ccountGroup
{
require
.
NoError
(
t
,
db
.
Create
(
&
a
ccountGroup
Model
{
AccountID
:
accountID
,
AccountID
:
accountID
,
GroupID
:
groupID
,
GroupID
:
groupID
,
Priority
:
priority
,
Priority
:
priority
,
CreatedAt
:
time
.
Now
(),
})
.
Error
,
"create account_group"
)
})
.
Error
,
"create account_group"
)
}
}
backend/internal/repository/group_repo.go
View file @
22f07a7b
...
@@ -2,10 +2,10 @@ package repository
...
@@ -2,10 +2,10 @@ package repository
import
(
import
(
"context"
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
"gorm.io/gorm"
...
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
...
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return
&
groupRepository
{
db
:
db
}
return
&
groupRepository
{
db
:
db
}
}
}
func
(
r
*
groupRepository
)
Create
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
func
(
r
*
groupRepository
)
Create
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
group
)
.
Error
m
:=
groupModelFromService
(
group
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
m
)
.
Error
if
err
==
nil
{
applyGroupModelToService
(
group
,
m
)
}
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrGroupExists
)
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrGroupExists
)
}
}
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
var
group
m
odel
.
Group
var
m
group
M
odel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
group
,
id
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
m
,
id
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
}
return
&
group
,
nil
group
:=
groupModelToService
(
&
m
)
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
group
.
ID
)
group
.
AccountCount
=
count
return
group
,
nil
}
}
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
group
)
.
Error
m
:=
groupModelFromService
(
group
)
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Save
(
m
)
.
Error
if
err
==
nil
{
applyGroupModelToService
(
group
,
m
)
}
return
err
}
}
func
(
r
*
groupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
groupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
groupModel
{},
id
)
.
Error
}
}
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
}
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
var
groups
[]
model
.
Group
var
groups
[]
groupModel
var
total
int64
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
groupModel
{})
// Apply filters
// Apply filters
if
platform
!=
""
{
if
platform
!=
""
{
...
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
...
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return
nil
,
nil
,
err
return
nil
,
nil
,
err
}
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
groups
[
i
]
.
AccountCount
=
count
}
}
pages
:=
int
(
total
)
/
params
.
Limit
()
// 获取每个分组的账号数量
if
int
(
total
)
%
params
.
Limit
()
>
0
{
for
i
:=
range
outGroups
{
pages
++
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
}
return
groups
,
&
pagination
.
PaginationResult
{
return
outGroups
,
paginationResultFromTotal
(
total
,
params
),
nil
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
}
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
var
groups
[]
model
.
Group
var
groups
[]
groupModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
service
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
groups
[
i
]
.
AccountCount
=
count
}
// 获取每个分组的账号数量
for
i
:=
range
outGroups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
}
return
g
roups
,
nil
return
outG
roups
,
nil
}
}
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
var
groups
[]
model
.
Group
var
groups
[]
groupModel
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
model
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
service
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 获取每个分组的账号数量
outGroups
:=
make
([]
service
.
Group
,
0
,
len
(
groups
))
for
i
:=
range
groups
{
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
outGroups
=
append
(
outGroups
,
*
groupModelToService
(
&
groups
[
i
]))
groups
[
i
]
.
AccountCount
=
count
}
// 获取每个分组的账号数量
for
i
:=
range
outGroups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
outGroups
[
i
]
.
ID
)
outGroups
[
i
]
.
AccountCount
=
count
}
}
return
g
roups
,
nil
return
outG
roups
,
nil
}
}
func
(
r
*
groupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
func
(
r
*
groupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
count
int64
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
groupModel
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
return
count
>
0
,
err
}
}
func
(
r
*
groupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
groupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
A
ccount
G
roup
{}
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Table
(
"a
ccount
_g
roup
s"
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
return
count
,
err
}
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func
(
r
*
groupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
groupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{}
)
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Exec
(
"DELETE FROM account_groups WHERE group_id = ?"
,
groupID
)
return
result
.
RowsAffected
,
result
.
Error
return
result
.
RowsAffected
,
result
.
Error
}
}
...
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
...
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var
affectedUserIDs
[]
int64
var
affectedUserIDs
[]
int64
if
group
.
IsSubscriptionType
()
{
if
group
.
IsSubscriptionType
()
{
var
subscriptions
[]
model
.
UserSubscription
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
U
ser
S
ubscription
{}
)
.
Table
(
"u
ser
_s
ubscription
s"
)
.
Where
(
"group_id = ?"
,
id
)
.
Where
(
"group_id = ?"
,
id
)
.
Select
(
"user_id"
)
.
Pluck
(
"user_id"
,
&
affectedUserIDs
)
.
Error
;
err
!=
nil
{
Find
(
&
subscriptions
)
.
Error
;
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
for
_
,
sub
:=
range
subscriptions
{
affectedUserIDs
=
append
(
affectedUserIDs
,
sub
.
UserID
)
}
}
}
err
=
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
err
=
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
// 1. 删除订阅类型分组的订阅记录
// 1. 删除订阅类型分组的订阅记录
if
group
.
IsSubscriptionType
()
{
if
group
.
IsSubscriptionType
()
{
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
UserSubscription
{}
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"DELETE FROM user_subscriptions WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if
err
:=
tx
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
id
)
.
Update
(
"group_id"
,
nil
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"UPDATE api_keys SET group_id = NULL WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if
err
:=
tx
.
Model
(
&
model
.
User
{})
.
if
err
:=
tx
.
Exec
(
Where
(
"? = ANY(allowed_groups)"
,
id
)
.
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)"
,
Update
(
"allowed_groups"
,
gorm
.
Expr
(
"array_remove(allowed_groups, ?)"
,
id
))
.
Error
;
err
!=
nil
{
id
,
id
,
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// 4. 删除 account_groups 中间表的数据
// 4. 删除 account_groups 中间表的数据
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{}
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Exec
(
"DELETE FROM account_groups WHERE group_id = ?"
,
id
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
// 5. 删除分组本身(带锁,避免并发写)
// 5. 删除分组本身(带锁,避免并发写)
if
err
:=
tx
.
Clauses
(
clause
.
Locking
{
Strength
:
"UPDATE"
})
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
;
err
!=
nil
{
if
err
:=
tx
.
Clauses
(
clause
.
Locking
{
Strength
:
"UPDATE"
})
.
Delete
(
&
groupModel
{},
id
)
.
Error
;
err
!=
nil
{
return
err
return
err
}
}
...
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
...
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return
affectedUserIDs
,
nil
return
affectedUserIDs
,
nil
}
}
type
groupModel
struct
{
ID
int64
`gorm:"primaryKey"`
Name
string
`gorm:"uniqueIndex;size:100;not null"`
Description
string
`gorm:"type:text"`
Platform
string
`gorm:"size:50;default:anthropic;not null"`
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive
bool
`gorm:"default:false;not null"`
Status
string
`gorm:"size:20;default:active;not null"`
SubscriptionType
string
`gorm:"size:20;default:standard;not null"`
DailyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
WeeklyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
MonthlyLimitUSD
*
float64
`gorm:"type:decimal(20,8)"`
CreatedAt
time
.
Time
`gorm:"not null"`
UpdatedAt
time
.
Time
`gorm:"not null"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index"`
}
func
(
groupModel
)
TableName
()
string
{
return
"groups"
}
func
groupModelToService
(
m
*
groupModel
)
*
service
.
Group
{
if
m
==
nil
{
return
nil
}
return
&
service
.
Group
{
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Description
:
m
.
Description
,
Platform
:
m
.
Platform
,
RateMultiplier
:
m
.
RateMultiplier
,
IsExclusive
:
m
.
IsExclusive
,
Status
:
m
.
Status
,
SubscriptionType
:
m
.
SubscriptionType
,
DailyLimitUSD
:
m
.
DailyLimitUSD
,
WeeklyLimitUSD
:
m
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
m
.
MonthlyLimitUSD
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
}
}
func
groupModelFromService
(
sg
*
service
.
Group
)
*
groupModel
{
if
sg
==
nil
{
return
nil
}
return
&
groupModel
{
ID
:
sg
.
ID
,
Name
:
sg
.
Name
,
Description
:
sg
.
Description
,
Platform
:
sg
.
Platform
,
RateMultiplier
:
sg
.
RateMultiplier
,
IsExclusive
:
sg
.
IsExclusive
,
Status
:
sg
.
Status
,
SubscriptionType
:
sg
.
SubscriptionType
,
DailyLimitUSD
:
sg
.
DailyLimitUSD
,
WeeklyLimitUSD
:
sg
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
sg
.
MonthlyLimitUSD
,
CreatedAt
:
sg
.
CreatedAt
,
UpdatedAt
:
sg
.
UpdatedAt
,
}
}
func
applyGroupModelToService
(
group
*
service
.
Group
,
m
*
groupModel
)
{
if
group
==
nil
||
m
==
nil
{
return
}
group
.
ID
=
m
.
ID
group
.
CreatedAt
=
m
.
CreatedAt
group
.
UpdatedAt
=
m
.
UpdatedAt
}
backend/internal/repository/group_repo_integration_test.go
View file @
22f07a7b
...
@@ -6,8 +6,8 @@ import (
...
@@ -6,8 +6,8 @@ import (
"context"
"context"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
"gorm.io/gorm"
)
)
...
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
...
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
// --- Create / GetByID / Update / Delete ---
func
(
s
*
GroupRepoSuite
)
TestCreate
()
{
func
(
s
*
GroupRepoSuite
)
TestCreate
()
{
group
:=
&
model
.
Group
{
group
:=
&
service
.
Group
{
Name
:
"test-create"
,
Name
:
"test-create"
,
Platform
:
model
.
PlatformAnthropic
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
}
}
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
group
)
err
:=
s
.
repo
.
Create
(
s
.
ctx
,
group
)
...
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
...
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
}
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"original"
})
group
:=
groupModelToService
(
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"original"
})
)
group
.
Name
=
"updated"
group
.
Name
=
"updated"
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
group
)
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
group
)
...
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
...
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
}
}
func
(
s
*
GroupRepoSuite
)
TestDelete
()
{
func
(
s
*
GroupRepoSuite
)
TestDelete
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"to-delete"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"to-delete"
})
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
s
.
Require
()
.
NoError
(
err
,
"Delete"
)
...
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
...
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
// --- List / ListWithFilters ---
func
(
s
*
GroupRepoSuite
)
TestList
()
{
func
(
s
*
GroupRepoSuite
)
TestList
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
})
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
...
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
...
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Platform
:
service
.
PlatformAnthropic
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Platform
:
service
.
PlatformOpenAI
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformOpenAI
,
""
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
s
.
Require
()
.
Equal
(
service
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Status
:
service
.
StatusDisabled
})
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
model
.
StatusDisabled
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
model
.
StatusDisabled
,
groups
[
0
]
.
Status
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
groups
[
0
]
.
Status
)
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_IsExclusive
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_IsExclusive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
IsExclusive
:
false
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
IsExclusive
:
false
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
IsExclusive
:
true
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
IsExclusive
:
true
})
isExclusive
:=
true
isExclusive
:=
true
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
...
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
...
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
g1
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
})
})
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
g2
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Name
:
"g2"
,
Platform
:
model
.
PlatformAnthropic
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
,
Status
:
service
.
StatusActive
,
IsExclusive
:
true
,
IsExclusive
:
true
,
})
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc1"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc1"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g1
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g1
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g2
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g2
.
ID
,
1
)
isExclusive
:=
true
isExclusive
:=
true
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
model
.
PlatformAnthropic
,
model
.
StatusActive
,
&
isExclusive
)
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Len
(
groups
,
1
)
...
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
...
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
// --- ListActive / ListActiveByPlatform ---
func
(
s
*
GroupRepoSuite
)
TestListActive
()
{
func
(
s
*
GroupRepoSuite
)
TestListActive
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"active1"
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"active1"
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"inactive1"
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"inactive1"
,
Status
:
service
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
...
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
...
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
}
}
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g1"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g1"
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g2"
,
Platform
:
model
.
PlatformOpenAI
,
Status
:
model
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g2"
,
Platform
:
service
.
PlatformOpenAI
,
Status
:
service
.
StatusActive
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g3"
,
Platform
:
model
.
PlatformAnthropic
,
Status
:
model
.
StatusDisabled
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g3"
,
Platform
:
service
.
PlatformAnthropic
,
Status
:
service
.
StatusDisabled
})
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
model
.
PlatformAnthropic
)
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
"g1"
,
groups
[
0
]
.
Name
)
s
.
Require
()
.
Equal
(
"g1"
,
groups
[
0
]
.
Name
)
...
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
...
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName ---
// --- ExistsByName ---
func
(
s
*
GroupRepoSuite
)
TestExistsByName
()
{
func
(
s
*
GroupRepoSuite
)
TestExistsByName
()
{
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"existing-group"
})
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"existing-group"
})
exists
,
err
:=
s
.
repo
.
ExistsByName
(
s
.
ctx
,
"existing-group"
)
exists
,
err
:=
s
.
repo
.
ExistsByName
(
s
.
ctx
,
"existing-group"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByName"
)
s
.
Require
()
.
NoError
(
err
,
"ExistsByName"
)
...
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
...
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount ---
// --- GetAccountCount ---
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount
()
{
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-count"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-count"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
group
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
group
.
ID
,
2
)
...
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
...
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
}
}
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount_Empty
()
{
func
(
s
*
GroupRepoSuite
)
TestGetAccountCount_Empty
()
{
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-empty"
})
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-empty"
})
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
group
.
ID
)
count
,
err
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
...
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
...
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID ---
// --- DeleteAccountGroupsByGroupID ---
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID
()
{
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-del"
})
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-del"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"acc-del"
})
a
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"acc-del"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a
.
ID
,
g
.
ID
,
1
)
affected
,
err
:=
s
.
repo
.
DeleteAccountGroupsByGroupID
(
s
.
ctx
,
g
.
ID
)
affected
,
err
:=
s
.
repo
.
DeleteAccountGroupsByGroupID
(
s
.
ctx
,
g
.
ID
)
...
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
...
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
}
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID_MultipleAccounts
()
{
func
(
s
*
GroupRepoSuite
)
TestDeleteAccountGroupsByGroupID_MultipleAccounts
()
{
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
model
.
Group
{
Name
:
"g-multi"
})
g
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"g-multi"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a1"
})
a1
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a1"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a2"
})
a2
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a2"
})
a3
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
model
.
A
ccount
{
Name
:
"a3"
})
a3
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
a
ccount
Model
{
Name
:
"a3"
})
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
g
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a1
.
ID
,
g
.
ID
,
1
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
g
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a2
.
ID
,
g
.
ID
,
2
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a3
.
ID
,
g
.
ID
,
3
)
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
a3
.
ID
,
g
.
ID
,
3
)
...
...
backend/internal/repository/integration_harness_test.go
View file @
22f07a7b
...
@@ -15,7 +15,6 @@ import (
...
@@ -15,7 +15,6 @@ import (
"testing"
"testing"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/stretchr/testify/suite"
...
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
...
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log
.
Printf
(
"failed to open gorm db: %v"
,
err
)
log
.
Printf
(
"failed to open gorm db: %v"
,
err
)
os
.
Exit
(
1
)
os
.
Exit
(
1
)
}
}
if
err
:=
model
.
AutoMigrate
(
integrationDB
);
err
!=
nil
{
if
err
:=
AutoMigrate
(
integrationDB
);
err
!=
nil
{
log
.
Printf
(
"failed to automigrate db: %v"
,
err
)
log
.
Printf
(
"failed to automigrate db: %v"
,
err
)
os
.
Exit
(
1
)
os
.
Exit
(
1
)
}
}
...
...
backend/internal/repository/pagination.go
0 → 100644
View file @
22f07a7b
package
repository
import
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func
paginationResultFromTotal
(
total
int64
,
params
pagination
.
PaginationParams
)
*
pagination
.
PaginationResult
{
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
}
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment