"backend/ent/vscode:/vscode.git/clone" did not exist on "0d6c1c779053beb171b8158b1340f792eb5f9bd3"
Commit 3d617de5 authored by yangjianbo's avatar yangjianbo
Browse files

refactor(数据库): 迁移持久层到 Ent 并清理 GORM

将仓储层/基础设施改为 Ent + 原生 SQL 执行路径,并移除 AutoMigrate 与 GORM 依赖。
重构内容包括:
- 仓储层改用 Ent/SQL(含 usage_log/account 等复杂查询),统一错误映射
- 基础设施与 setup 初始化切换为 Ent + SQL migrations
- 集成测试与 fixtures 迁移到 Ent 事务模型
- 清理遗留 GORM 模型/依赖,补充迁移与文档说明
- 增加根目录 Makefile 便于前后端编译

测试:
- go test -tags unit ./...
- go test -tags integration ./...
parent fd51ff69
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionDelete is the builder for deleting a UserSubscription entity.
type UserSubscriptionDelete struct {
config
hooks []Hook
mutation *UserSubscriptionMutation
}
// Where appends a list predicates to the UserSubscriptionDelete builder.
func (_d *UserSubscriptionDelete) Where(ps ...predicate.UserSubscription) *UserSubscriptionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserSubscriptionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserSubscriptionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserSubscriptionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(usersubscription.Table, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserSubscriptionDeleteOne is the builder for deleting a single UserSubscription entity.
type UserSubscriptionDeleteOne struct {
_d *UserSubscriptionDelete
}
// Where appends a list predicates to the UserSubscriptionDelete builder.
func (_d *UserSubscriptionDeleteOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserSubscriptionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{usersubscription.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserSubscriptionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionQuery is the builder for querying UserSubscription entities.
type UserSubscriptionQuery struct {
config
ctx *QueryContext
order []usersubscription.OrderOption
inters []Interceptor
predicates []predicate.UserSubscription
withUser *UserQuery
withGroup *GroupQuery
withAssignedByUser *UserQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserSubscriptionQuery builder.
func (_q *UserSubscriptionQuery) Where(ps ...predicate.UserSubscription) *UserSubscriptionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserSubscriptionQuery) Limit(limit int) *UserSubscriptionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserSubscriptionQuery) Offset(offset int) *UserSubscriptionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserSubscriptionQuery) Unique(unique bool) *UserSubscriptionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserSubscriptionQuery) Order(o ...usersubscription.OrderOption) *UserSubscriptionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserSubscriptionQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.UserTable, usersubscription.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryGroup chains the current query on the "group" edge.
func (_q *UserSubscriptionQuery) QueryGroup() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.GroupTable, usersubscription.GroupColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAssignedByUser chains the current query on the "assigned_by_user" edge.
func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.AssignedByUserTable, usersubscription.AssignedByUserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found.
func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{usersubscription.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserSubscriptionQuery) FirstX(ctx context.Context) *UserSubscription {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserSubscription ID from the query.
// Returns a *NotFoundError when no UserSubscription ID was found.
func (_q *UserSubscriptionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{usersubscription.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserSubscriptionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserSubscription entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserSubscription entity is found.
// Returns a *NotFoundError when no UserSubscription entities are found.
func (_q *UserSubscriptionQuery) Only(ctx context.Context) (*UserSubscription, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{usersubscription.Label}
default:
return nil, &NotSingularError{usersubscription.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserSubscriptionQuery) OnlyX(ctx context.Context) *UserSubscription {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserSubscription ID in the query.
// Returns a *NotSingularError when more than one UserSubscription ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserSubscriptionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{usersubscription.Label}
default:
err = &NotSingularError{usersubscription.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserSubscriptionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserSubscriptions.
func (_q *UserSubscriptionQuery) All(ctx context.Context) ([]*UserSubscription, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserSubscription, *UserSubscriptionQuery]()
return withInterceptors[[]*UserSubscription](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserSubscriptionQuery) AllX(ctx context.Context) []*UserSubscription {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserSubscription IDs.
func (_q *UserSubscriptionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(usersubscription.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserSubscriptionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserSubscriptionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserSubscriptionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserSubscriptionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserSubscriptionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserSubscriptionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserSubscriptionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
if _q == nil {
return nil
}
return &UserSubscriptionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]usersubscription.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserSubscription{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
withAssignedByUser: _q.withAssignedByUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithUser(opts ...func(*UserQuery)) *UserSubscriptionQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithGroup tells the query-builder to eager-load the nodes that are connected to
// the "group" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithGroup(opts ...func(*GroupQuery)) *UserSubscriptionQuery {
query := (&GroupClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withGroup = query
return _q
}
// WithAssignedByUser tells the query-builder to eager-load the nodes that are connected to
// the "assigned_by_user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *UserSubscriptionQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAssignedByUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserSubscription.Query().
// GroupBy(usersubscription.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserSubscriptionQuery) GroupBy(field string, fields ...string) *UserSubscriptionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserSubscriptionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = usersubscription.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserSubscription.Query().
// Select(usersubscription.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserSubscriptionQuery) Select(fields ...string) *UserSubscriptionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserSubscriptionSelect{UserSubscriptionQuery: _q}
sbuild.label = usersubscription.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserSubscriptionSelect configured with the given aggregations.
func (_q *UserSubscriptionQuery) Aggregate(fns ...AggregateFunc) *UserSubscriptionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserSubscriptionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !usersubscription.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserSubscription, error) {
var (
nodes = []*UserSubscription{}
_spec = _q.querySpec()
loadedTypes = [3]bool{
_q.withUser != nil,
_q.withGroup != nil,
_q.withAssignedByUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserSubscription).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserSubscription{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserSubscription, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withGroup; query != nil {
if err := _q.loadGroup(ctx, query, nodes, nil,
func(n *UserSubscription, e *Group) { n.Edges.Group = e }); err != nil {
return nil, err
}
}
if query := _q.withAssignedByUser; query != nil {
if err := _q.loadAssignedByUser(ctx, query, nodes, nil,
func(n *UserSubscription, e *User) { n.Edges.AssignedByUser = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserSubscriptionQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *Group)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
fk := nodes[i].GroupID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(group.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *UserQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
if nodes[i].AssignedBy == nil {
continue
}
fk := *nodes[i].AssignedBy
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "assigned_by" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserSubscriptionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usersubscription.FieldID)
for i := range fields {
if fields[i] != usersubscription.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldUserID)
}
if _q.withGroup != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldGroupID)
}
if _q.withAssignedByUser != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldAssignedBy)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(usersubscription.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = usersubscription.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
build *UserSubscriptionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserSubscriptionGroupBy) Aggregate(fns ...AggregateFunc) *UserSubscriptionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserSubscriptionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserSubscriptionQuery, *UserSubscriptionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserSubscriptionGroupBy) sqlScan(ctx context.Context, root *UserSubscriptionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserSubscriptionSelect is the builder for selecting fields of UserSubscription entities.
type UserSubscriptionSelect struct {
*UserSubscriptionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserSubscriptionSelect) Aggregate(fns ...AggregateFunc) *UserSubscriptionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserSubscriptionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserSubscriptionQuery, *UserSubscriptionSelect](ctx, _s.UserSubscriptionQuery, _s, _s.inters, v)
}
func (_s *UserSubscriptionSelect) sqlScan(ctx context.Context, root *UserSubscriptionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionUpdate is the builder for updating UserSubscription entities.
type UserSubscriptionUpdate struct {
config
hooks []Hook
mutation *UserSubscriptionMutation
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdate) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableUserID(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UserSubscriptionUpdate) SetGroupID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetGroupID(v)
return _u
}
// SetNillableGroupID sets the "group_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableGroupID(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetGroupID(*v)
}
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *UserSubscriptionUpdate) SetStartsAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableStartsAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *UserSubscriptionUpdate) SetExpiresAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableExpiresAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *UserSubscriptionUpdate) SetStatus(v string) *UserSubscriptionUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableStatus(v *string) *UserSubscriptionUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserSubscriptionUpdate) SetDailyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDailyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserSubscriptionUpdate) ClearDailyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserSubscriptionUpdate) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserSubscriptionUpdate) ClearWeeklyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserSubscriptionUpdate) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserSubscriptionUpdate) ClearMonthlyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetDailyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDailyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddDailyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetAssignedBy sets the "assigned_by" field.
func (_u *UserSubscriptionUpdate) SetAssignedBy(v int64) *UserSubscriptionUpdate {
_u.mutation.SetAssignedBy(v)
return _u
}
// SetNillableAssignedBy sets the "assigned_by" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedBy(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetAssignedBy(*v)
}
return _u
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (_u *UserSubscriptionUpdate) ClearAssignedBy() *UserSubscriptionUpdate {
_u.mutation.ClearAssignedBy()
return _u
}
// SetAssignedAt sets the "assigned_at" field.
func (_u *UserSubscriptionUpdate) SetAssignedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetAssignedAt(v)
return _u
}
// SetNillableAssignedAt sets the "assigned_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetAssignedAt(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserSubscriptionUpdate) SetNotes(v string) *UserSubscriptionUpdate {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableNotes(v *string) *UserSubscriptionUpdate {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *UserSubscriptionUpdate) ClearNotes() *UserSubscriptionUpdate {
_u.mutation.ClearNotes()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserSubscriptionUpdate) SetUser(v *User) *UserSubscriptionUpdate {
return _u.SetUserID(v.ID)
}
// SetGroup sets the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdate) SetGroup(v *Group) *UserSubscriptionUpdate {
return _u.SetGroupID(v.ID)
}
// SetAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID.
func (_u *UserSubscriptionUpdate) SetAssignedByUserID(id int64) *UserSubscriptionUpdate {
_u.mutation.SetAssignedByUserID(id)
return _u
}
// SetNillableAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedByUserID(id *int64) *UserSubscriptionUpdate {
if id != nil {
_u = _u.SetAssignedByUserID(*id)
}
return _u
}
// SetAssignedByUser sets the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUpdate {
return _u.SetAssignedByUserID(v.ID)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserSubscriptionUpdate) ClearUser() *UserSubscriptionUpdate {
_u.mutation.ClearUser()
return _u
}
// ClearGroup clears the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdate) ClearGroup() *UserSubscriptionUpdate {
_u.mutation.ClearGroup()
return _u
}
// ClearAssignedByUser clears the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate {
_u.mutation.ClearAssignedByUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserSubscriptionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserSubscriptionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserSubscriptionUpdate) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usersubscription.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UserSubscription.status": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.user"`)
}
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.group"`)
}
return nil
}
func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(usersubscription.FieldExpiresAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usersubscription.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(usersubscription.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(usersubscription.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(usersubscription.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldMonthlyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AssignedAt(); ok {
_spec.SetField(usersubscription.FieldAssignedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(usersubscription.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(usersubscription.FieldNotes, field.TypeString)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.GroupCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AssignedByUserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AssignedByUserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserSubscriptionUpdateOne is the builder for updating a single UserSubscription entity.
type UserSubscriptionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserSubscriptionMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableUserID(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UserSubscriptionUpdateOne) SetGroupID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetGroupID(v)
return _u
}
// SetNillableGroupID sets the "group_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableGroupID(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetGroupID(*v)
}
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *UserSubscriptionUpdateOne) SetStartsAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableStartsAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *UserSubscriptionUpdateOne) SetExpiresAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableExpiresAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *UserSubscriptionUpdateOne) SetStatus(v string) *UserSubscriptionUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableStatus(v *string) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetDailyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearDailyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearWeeklyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearMonthlyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetDailyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddDailyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetAssignedBy sets the "assigned_by" field.
func (_u *UserSubscriptionUpdateOne) SetAssignedBy(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedBy(v)
return _u
}
// SetNillableAssignedBy sets the "assigned_by" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedBy(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetAssignedBy(*v)
}
return _u
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (_u *UserSubscriptionUpdateOne) ClearAssignedBy() *UserSubscriptionUpdateOne {
_u.mutation.ClearAssignedBy()
return _u
}
// SetAssignedAt sets the "assigned_at" field.
func (_u *UserSubscriptionUpdateOne) SetAssignedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedAt(v)
return _u
}
// SetNillableAssignedAt sets the "assigned_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetAssignedAt(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserSubscriptionUpdateOne) SetNotes(v string) *UserSubscriptionUpdateOne {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableNotes(v *string) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *UserSubscriptionUpdateOne) ClearNotes() *UserSubscriptionUpdateOne {
_u.mutation.ClearNotes()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) SetUser(v *User) *UserSubscriptionUpdateOne {
return _u.SetUserID(v.ID)
}
// SetGroup sets the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdateOne) SetGroup(v *Group) *UserSubscriptionUpdateOne {
return _u.SetGroupID(v.ID)
}
// SetAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID.
func (_u *UserSubscriptionUpdateOne) SetAssignedByUserID(id int64) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedByUserID(id)
return _u
}
// SetNillableAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedByUserID(id *int64) *UserSubscriptionUpdateOne {
if id != nil {
_u = _u.SetAssignedByUserID(*id)
}
return _u
}
// SetAssignedByUser sets the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptionUpdateOne {
return _u.SetAssignedByUserID(v.ID)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) ClearUser() *UserSubscriptionUpdateOne {
_u.mutation.ClearUser()
return _u
}
// ClearGroup clears the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdateOne) ClearGroup() *UserSubscriptionUpdateOne {
_u.mutation.ClearGroup()
return _u
}
// ClearAssignedByUser clears the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpdateOne {
_u.mutation.ClearAssignedByUser()
return _u
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *UserSubscriptionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserSubscription entity.
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserSubscriptionUpdateOne) SaveX(ctx context.Context) *UserSubscription {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserSubscriptionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserSubscriptionUpdateOne) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usersubscription.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UserSubscription.status": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.user"`)
}
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.group"`)
}
return nil
}
func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSubscription, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserSubscription.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usersubscription.FieldID)
for _, f := range fields {
if !usersubscription.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != usersubscription.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(usersubscription.FieldExpiresAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usersubscription.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(usersubscription.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(usersubscription.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(usersubscription.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldMonthlyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AssignedAt(); ok {
_spec.SetField(usersubscription.FieldAssignedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(usersubscription.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(usersubscription.FieldNotes, field.TypeString)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.GroupCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AssignedByUserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AssignedByUserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserSubscription{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
...@@ -5,6 +5,7 @@ go 1.24.0 ...@@ -5,6 +5,7 @@ go 1.24.0
toolchain go1.24.11 toolchain go1.24.11
require ( require (
entgo.io/ent v0.14.5
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0 github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
...@@ -23,17 +24,17 @@ require ( ...@@ -23,17 +24,17 @@ require (
golang.org/x/net v0.47.0 golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
) )
require ( require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
...@@ -58,22 +59,19 @@ require ( ...@@ -58,22 +59,19 @@ require (
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.9.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.4 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
...@@ -82,7 +80,9 @@ require ( ...@@ -82,7 +80,9 @@ require (
github.com/magiconair/properties v1.8.10 // indirect github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect github.com/moby/go-archive v0.1.0 // indirect
...@@ -94,6 +94,7 @@ require ( ...@@ -94,6 +94,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
...@@ -103,6 +104,7 @@ require ( ...@@ -103,6 +104,7 @@ require (
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect github.com/refraction-networking/utls v1.8.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect
...@@ -111,6 +113,7 @@ require ( ...@@ -111,6 +113,7 @@ require (
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/cobra v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect
...@@ -121,6 +124,8 @@ require ( ...@@ -121,6 +124,8 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zclconf/go-cty v1.14.4 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect
...@@ -137,8 +142,8 @@ require ( ...@@ -137,8 +142,8 @@ require (
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect golang.org/x/tools v0.38.0 // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/grpc v1.75.1 // indirect google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
) )
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 h1:E0wvcUXTkgyN4wy4LGtNzMNGMytJN8afmIWXJVMi4cc=
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9/go.mod h1:Oe1xWPuu5q9LzyrWfbZmEZxFYeu4BHTyzfjeW2aZp/w=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= entgo.io/ent v0.14.5 h1:Rj2WOYJtCkWyFo6a+5wB3EfBRP0rnx1fMk6gGA0UUe4=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= entgo.io/ent v0.14.5/go.mod h1:zTzLmWtPvGpmSwtkaayM2cm5m819NdM7z7tYPq3vN0U=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
...@@ -34,6 +44,7 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS ...@@ -34,6 +44,7 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
...@@ -73,6 +84,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= ...@@ -73,6 +84,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNPXu/4=
github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
...@@ -81,17 +94,12 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn ...@@ -81,17 +94,12 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
...@@ -109,10 +117,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLW ...@@ -109,10 +117,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLW
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo= github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk= github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
...@@ -121,10 +133,6 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= ...@@ -121,10 +133,6 @@ github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
...@@ -136,6 +144,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= ...@@ -136,6 +144,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
...@@ -149,12 +159,15 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk ...@@ -149,12 +159,15 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
...@@ -180,6 +193,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G ...@@ -180,6 +193,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
...@@ -203,12 +218,17 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi ...@@ -203,12 +218,17 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
...@@ -221,6 +241,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= ...@@ -221,6 +241,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
...@@ -269,6 +291,10 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ ...@@ -269,6 +291,10 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=
github.com/zclconf/go-cty v1.14.4/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE=
github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0=
github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs=
github.com/zeromicro/go-zero v1.9.4 h1:aRLFoISqAYijABtkbliQC5SsI5TbizJpQvoHc9xup8k= github.com/zeromicro/go-zero v1.9.4 h1:aRLFoISqAYijABtkbliQC5SsI5TbizJpQvoHc9xup8k=
github.com/zeromicro/go-zero v1.9.4/go.mod h1:a17JOTch25SWxBcUgJZYps60hygK3pIYdw7nGwlcS38= github.com/zeromicro/go-zero v1.9.4/go.mod h1:a17JOTch25SWxBcUgJZYps60hygK3pIYdw7nGwlcS38=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
...@@ -329,6 +355,10 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= ...@@ -329,6 +355,10 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
...@@ -347,19 +377,6 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= ...@@ -347,19 +377,6 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
...@@ -195,8 +195,10 @@ func setDefaults() { ...@@ -195,8 +195,10 @@ func setDefaults() {
viper.SetDefault("jwt.expire_hour", 24) viper.SetDefault("jwt.expire_hour", 24)
// Default // Default
viper.SetDefault("default.admin_email", "admin@sub2api.com") // Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
viper.SetDefault("default.admin_password", "admin123") // Do not ship fixed defaults here to avoid insecure "known credentials" in production.
viper.SetDefault("default.admin_email", "")
viper.SetDefault("default.admin_password", "")
viper.SetDefault("default.user_concurrency", 5) viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0) viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-") viper.SetDefault("default.api_key_prefix", "sk-")
......
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := repository.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure
import (
"context"
"database/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/migrations"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
)
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
//
// 该函数执行以下操作:
// 1. 初始化全局时区设置,确保时间处理一致性
// 2. 建立 PostgreSQL 数据库连接
// 3. 自动执行数据库迁移,确保 schema 与代码同步
// 4. 创建并返回 Ent 客户端实例
//
// 重要提示:调用者必须负责关闭返回的 ent.Client(关闭时会自动关闭底层的 driver/db)。
//
// 参数:
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
//
// 返回:
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
// - error: 初始化过程中的错误
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
// 这对于跨时区部署和日志时间戳的一致性至关重要。
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, nil, err
}
// 构建包含时区信息的数据库连接字符串 (DSN)。
// 时区信息会传递给 PostgreSQL,确保数据库层面的时间处理正确。
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
drv, err := entsql.Open(dialect.Postgres, dsn)
if err != nil {
return nil, nil, err
}
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
if err := applyMigrationsFS(context.Background(), drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
return nil, nil, err
}
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
return client, drv.DB(), nil
}
package infrastructure
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"github.com/Wei-Shaw/sub2api/migrations"
)
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
// 该表用于跟踪已应用的迁移文件及其校验和。
// - filename: 迁移文件名,作为主键唯一标识每个迁移
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
// - applied_at: 迁移应用时间戳
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
checksum TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
// 该函数可以在每次应用启动时安全调用:
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
// - 如果迁移文件内容被修改(checksum 不匹配),会返回错误
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
//
// 参数:
// - ctx: 上下文,用于超时控制和取消
// - db: 数据库连接
//
// 返回:
// - error: 迁移过程中的任何错误
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if db == nil {
return errors.New("nil sql db")
}
return applyMigrationsFS(ctx, db, migrations.FS)
}
// applyMigrationsFS 是迁移执行的核心实现。
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
//
// 迁移执行流程:
// 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移
// 2. 确保 schema_migrations 表存在
// 3. 按文件名排序读取所有 .sql 文件
// 4. 对于每个迁移文件:
// - 计算文件内容的 SHA256 校验和
// - 检查该迁移是否已应用(通过 filename 查询)
// - 如果已应用,验证校验和是否匹配
// - 如果未应用,在事务中执行迁移并记录
// 5. 释放 Advisory Lock
//
// 参数:
// - ctx: 上下文
// - db: 数据库连接
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS)
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if db == nil {
return errors.New("nil sql db")
}
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
if err := pgAdvisoryLock(ctx, db); err != nil {
return err
}
defer func() {
// 无论迁移是否成功,都要释放锁。
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
_ = pgAdvisoryUnlock(context.Background(), db)
}()
// 创建迁移记录表(如果不存在)。
// 该表记录所有已应用的迁移及其校验和。
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return fmt.Errorf("list migrations: %w", err)
}
sort.Strings(files) // 确保按文件名顺序执行迁移
for _, name := range files {
// 读取迁移文件内容
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return fmt.Errorf("read migration %s: %w", name, err)
}
content := strings.TrimSpace(string(contentBytes))
if content == "" {
continue // 跳过空文件
}
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
sum := sha256.Sum256([]byte(content))
checksum := hex.EncodeToString(sum[:])
// 检查该迁移是否已经应用
var existing string
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
}
continue // 迁移已应用且校验和匹配,跳过
}
if !errors.Is(rowErr, sql.ErrNoRows) {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
}
// 执行迁移 SQL
if _, err := tx.ExecContext(ctx, content); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %s: %w", name, err)
}
// 记录迁移已完成,保存文件名和校验和
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %s: %w", name, err)
}
// 提交事务
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return fmt.Errorf("commit migration %s: %w", name, err)
}
}
return nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_lock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("acquire migrations lock: %w", err)
}
return nil
}
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("release migrations lock: %w", err)
}
return nil
}
package infrastructure package infrastructure
import ( import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
entsql "entgo.io/ent/dialect/sql"
) )
// ProviderSet 提供基础设施层的依赖 // ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
ProvideDB, ProvideEnt,
ProvideSQLDB,
ProvideRedis, ProvideRedis,
) )
// ProvideDB 提供数据库连接 // ProvideEnt 为依赖注入提供 Ent 客户端。
func ProvideDB(cfg *config.Config) (*gorm.DB, error) { //
return InitDB(cfg) // 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
} }
// ProvideRedis 提供 Redis 客户端 // ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client { func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg) return InitRedis(cfg)
} }
// Package repository 实现数据访问层(Repository Pattern)。
//
// 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。
// 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。
//
// 主要特性:
// - 使用 Ent ORM 进行类型安全的数据库操作
// - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL
// - 提供统一的错误翻译机制,将数据库错误转换为业务错误
// - 支持软删除,所有查询自动过滤已删除记录
package repository package repository
import ( import (
"context" "context"
"errors" "database/sql"
"encoding/json"
"strconv"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"gorm.io/datatypes" entsql "entgo.io/ent/dialect/sql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
// accountRepository 实现 service.AccountRepository 接口。
// 提供 AI API 账户的完整数据访问功能。
//
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - begin: SQL 事务开启器,用于需要事务的操作
type accountRepository struct { type accountRepository struct {
db *gorm.DB client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
begin sqlBeginner // 事务开启接口
} }
func NewAccountRepository(db *gorm.DB) service.AccountRepository { // NewAccountRepository 创建账户仓储实例。
return &accountRepository{db: db} // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &accountRepository{client: client, sql: sqlq, begin: beginner}
} }
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account) if account == nil {
err := r.db.WithContext(ctx).Create(m).Error return nil
if err == nil {
applyAccountModelToService(account, m)
} }
return err
builder := r.client.Account.Create().
SetName(account.Name).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
SetExtra(normalizeJSONMap(account.Extra)).
SetConcurrency(account.Concurrency).
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
}
if account.RateLimitResetAt != nil {
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
}
if account.OverloadUntil != nil {
builder.SetOverloadUntil(*account.OverloadUntil)
}
if account.SessionWindowStart != nil {
builder.SetSessionWindowStart(*account.SessionWindowStart)
}
if account.SessionWindowEnd != nil {
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
}
if account.SessionWindowStatus != "" {
builder.SetSessionWindowStatus(account.SessionWindowStatus)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
account.ID = created.ID
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
return nil
} }
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel m, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx)
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
return accountModelToService(&m), nil
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, service.ErrAccountNotFound
}
return &accounts[0], nil
} }
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
...@@ -44,31 +133,100 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID ...@@ -44,31 +133,100 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return nil, nil return nil, nil
} }
var m accountModel m, err := r.client.Account.Query().
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error Where(func(s *entsql.Selector) {
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
}).
Only(ctx)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if dbent.IsNotFound(err) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
return accountModelToService(&m), nil
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, nil
}
return &accounts[0], nil
} }
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account) if account == nil {
err := r.db.WithContext(ctx).Save(m).Error return nil
if err == nil {
applyAccountModelToService(account, m)
} }
return err
builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
SetExtra(normalizeJSONMap(account.Extra)).
SetConcurrency(account.Concurrency).
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
builder.ClearProxyID()
}
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
} else {
builder.ClearLastUsedAt()
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
} else {
builder.ClearRateLimitedAt()
}
if account.RateLimitResetAt != nil {
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
} else {
builder.ClearRateLimitResetAt()
}
if account.OverloadUntil != nil {
builder.SetOverloadUntil(*account.OverloadUntil)
} else {
builder.ClearOverloadUntil()
}
if account.SessionWindowStart != nil {
builder.SetSessionWindowStart(*account.SessionWindowStart)
} else {
builder.ClearSessionWindowStart()
}
if account.SessionWindowEnd != nil {
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
} else {
builder.ClearSessionWindowEnd()
}
if account.SessionWindowStatus != "" {
builder.SetSessionWindowStatus(account.SessionWindowStatus)
} else {
builder.ClearSessionWindowStatus()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.UpdatedAt = updated.UpdatedAt
return nil
} }
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil { if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err return err
} }
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error _, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
...@@ -76,99 +234,84 @@ func (r *accountRepository) List(ctx context.Context, params pagination.Paginati ...@@ -76,99 +234,84 @@ func (r *accountRepository) List(ctx context.Context, params pagination.Paginati
} }
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel q := r.client.Account.Query()
var total int64
db := r.db.WithContext(ctx).Model(&accountModel{})
if platform != "" { if platform != "" {
db = db.Where("platform = ?", platform) q = q.Where(dbaccount.PlatformEQ(platform))
} }
if accountType != "" { if accountType != "" {
db = db.Where("type = ?", accountType) q = q.Where(dbaccount.TypeEQ(accountType))
} }
if status != "" { if status != "" {
db = db.Where("status = ?", status) q = q.Where(dbaccount.StatusEQ(status))
} }
if search != "" { if search != "" {
searchPattern := "%" + search + "%" q = q.Where(dbaccount.NameContainsFold(search))
db = db.Where("name ILIKE ?", searchPattern)
} }
if err := db.Count(&total).Error; err != nil { total, err := q.Count(ctx)
if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil { accounts, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbaccount.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err return nil, nil, err
} }
outAccounts := make([]service.Account, 0, len(accounts)) outAccounts, err := r.accountsToService(ctx, accounts)
for i := range accounts { if err != nil {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) return nil, nil, err
} }
return outAccounts, paginationResultFromTotal(int64(total), params), nil
return outAccounts, paginationResultFromTotal(total, params), nil
} }
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
err := r.db.WithContext(ctx). status: service.StatusActive,
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). })
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return accounts, nil
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) { func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel accounts, err := r.client.Account.Query().
err := r.db.WithContext(ctx). Where(dbaccount.StatusEQ(service.StatusActive)).
Where("status = ?", service.StatusActive). Order(dbent.Asc(dbaccount.FieldPriority)).
Preload("Proxy"). All(ctx)
Order("priority ASC").
Find(&accounts).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.accountsToService(ctx, accounts)
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel accounts, err := r.client.Account.Query().
err := r.db.WithContext(ctx). Where(
Where("platform = ? AND status = ?", platform, service.StatusActive). dbaccount.PlatformEQ(platform),
Preload("Proxy"). dbaccount.StatusEQ(service.StatusActive),
Order("priority ASC"). ).
Find(&accounts).Error Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.accountsToService(ctx, accounts)
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error _, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetLastUsedAt(now).
Save(ctx)
return err
} }
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
...@@ -176,63 +319,72 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map ...@@ -176,63 +319,72 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
return nil return nil
} }
var caseSql = "UPDATE accounts SET last_used_at = CASE id" ids := make([]int64, 0, len(updates))
var args []any args := make([]any, 0, len(updates)*2+1)
var ids []int64 caseSQL := "UPDATE accounts SET last_used_at = CASE id"
idx := 1
for id, ts := range updates { for id, ts := range updates {
caseSql += " WHEN ? THEN CAST(? AS TIMESTAMP)" caseSQL += " WHEN $" + itoa(idx) + " THEN $" + itoa(idx+1)
args = append(args, id, ts) args = append(args, id, ts)
ids = append(ids, id) ids = append(ids, id)
idx += 2
} }
caseSql += " END WHERE id IN ?" caseSQL += " END, updated_at = NOW() WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
args = append(args, ids) args = append(args, pq.Array(ids))
return r.db.WithContext(ctx).Exec(caseSql, args...).Error _, err := r.sql.ExecContext(ctx, caseSQL, args...)
return err
} }
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err := r.client.Account.Update().
Updates(map[string]any{ Where(dbaccount.IDEQ(id)).
"status": service.StatusError, SetStatus(service.StatusError).
"error_message": errorMsg, SetErrorMessage(errorMsg).
}).Error Save(ctx)
return err
} }
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &accountGroupModel{ _, err := r.client.AccountGroup.Create().
AccountID: accountID, SetAccountID(accountID).
GroupID: groupID, SetGroupID(groupID).
Priority: priority, SetPriority(priority).
} Save(ctx)
return r.db.WithContext(ctx).Create(ag).Error return err
} }
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). _, err := r.client.AccountGroup.Delete().
Delete(&accountGroupModel{}).Error Where(
dbaccountgroup.AccountIDEQ(accountID),
dbaccountgroup.GroupIDEQ(groupID),
).
Exec(ctx)
return err
} }
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel groups, err := r.client.Group.Query().
err := r.db.WithContext(ctx). Where(
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID)),
Where("account_groups.account_id = ?", accountID). ).
Find(&groups).Error All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outGroups := make([]service.Group, 0, len(groups)) outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i])) outGroups = append(outGroups, *groupEntityToService(groups[i]))
} }
return outGroups, nil return outGroups, nil
} }
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil { if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err return err
} }
...@@ -240,142 +392,117 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro ...@@ -240,142 +392,117 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
return nil return nil
} }
accountGroups := make([]accountGroupModel, 0, len(groupIDs)) builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs { for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{ builders = append(builders, r.client.AccountGroup.Create().
AccountID: accountID, SetAccountID(accountID).
GroupID: groupID, SetGroupID(groupID).
Priority: i + 1, SetPriority(i+1),
}) )
} }
return r.db.WithContext(ctx).Create(&accountGroups).Error
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
} }
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) { func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). accounts, err := r.client.Account.Query().
Where("status = ? AND schedulable = ?", service.StatusActive, true). Where(
Where("(overload_until IS NULL OR overload_until <= ?)", now). dbaccount.StatusEQ(service.StatusActive),
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). dbaccount.SchedulableEQ(true),
Preload("Proxy"). dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
Order("priority ASC"). dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
Find(&accounts).Error ).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outAccounts := make([]service.Account, 0, len(accounts)) return r.accountsToService(ctx, accounts)
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
now := time.Now() status: service.StatusActive,
err := r.db.WithContext(ctx). schedulable: true,
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). })
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). accounts, err := r.client.Account.Query().
Where("platform = ?", platform). Where(
Where("status = ? AND schedulable = ?", service.StatusActive, true). dbaccount.PlatformEQ(platform),
Where("(overload_until IS NULL OR overload_until <= ?)", now). dbaccount.StatusEQ(service.StatusActive),
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). dbaccount.SchedulableEQ(true),
Preload("Proxy"). dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
Order("priority ASC"). dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
Find(&accounts).Error ).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outAccounts := make([]service.Account, 0, len(accounts)) return r.accountsToService(ctx, accounts)
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
now := time.Now() status: service.StatusActive,
err := r.db.WithContext(ctx). schedulable: true,
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). platform: platform,
Where("account_groups.group_id = ?", groupID). })
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err := r.client.Account.Update().
Updates(map[string]any{ Where(dbaccount.IDEQ(id)).
"rate_limited_at": now, SetRateLimitedAt(now).
"rate_limit_reset_at": resetAt, SetRateLimitResetAt(resetAt).
}).Error Save(ctx)
return err
} }
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err := r.client.Account.Update().
Update("overload_until", until).Error Where(dbaccount.IDEQ(id)).
SetOverloadUntil(until).
Save(ctx)
return err
} }
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err := r.client.Account.Update().
Updates(map[string]any{ Where(dbaccount.IDEQ(id)).
"rate_limited_at": nil, ClearRateLimitedAt().
"rate_limit_reset_at": nil, ClearRateLimitResetAt().
"overload_until": nil, ClearOverloadUntil().
}).Error Save(ctx)
return err
} }
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ builder := r.client.Account.Update().
"session_window_status": status, Where(dbaccount.IDEQ(id)).
} SetSessionWindowStatus(status)
if start != nil { if start != nil {
updates["session_window_start"] = start builder.SetSessionWindowStart(*start)
} }
if end != nil { if end != nil {
updates["session_window_end"] = end builder.SetSessionWindowEnd(*end)
} }
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error _, err := builder.Save(ctx)
return err
} }
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err := r.client.Account.Update().
Update("schedulable", schedulable).Error Where(dbaccount.IDEQ(id)).
SetSchedulable(schedulable).
Save(ctx)
return err
} }
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
...@@ -383,20 +510,24 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m ...@@ -383,20 +510,24 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil return nil
} }
var account accountModel accountExtra, err := r.client.Account.Query().
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil { Where(dbaccount.IDEQ(id)).
return err Select(dbaccount.FieldExtra).
Only(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
if account.Extra == nil { extra := normalizeJSONMap(accountExtra.Extra)
account.Extra = datatypes.JSONMap{}
}
for k, v := range updates { for k, v := range updates {
account.Extra[k] = v extra[k] = v
} }
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). _, err = r.client.Account.Update().
Update("extra", account.Extra).Error Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
} }
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
...@@ -404,129 +535,260 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -404,129 +535,260 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
return 0, nil return 0, nil
} }
updateMap := map[string]any{} setClauses := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if updates.Name != nil { if updates.Name != nil {
updateMap["name"] = *updates.Name setClauses = append(setClauses, "name = $"+itoa(idx))
args = append(args, *updates.Name)
idx++
} }
if updates.ProxyID != nil { if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
} }
if updates.Concurrency != nil { if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency setClauses = append(setClauses, "concurrency = $"+itoa(idx))
args = append(args, *updates.Concurrency)
idx++
} }
if updates.Priority != nil { if updates.Priority != nil {
updateMap["priority"] = *updates.Priority setClauses = append(setClauses, "priority = $"+itoa(idx))
args = append(args, *updates.Priority)
idx++
} }
if updates.Status != nil { if updates.Status != nil {
updateMap["status"] = *updates.Status setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
idx++
} }
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 { if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials)) payload, err := json.Marshal(updates.Credentials)
if err != nil {
return 0, err
}
setClauses = append(setClauses, "credentials = COALESCE(credentials, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
args = append(args, payload)
idx++
} }
if len(updates.Extra) > 0 { if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra)) payload, err := json.Marshal(updates.Extra)
if err != nil {
return 0, err
}
setClauses = append(setClauses, "extra = COALESCE(extra, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
args = append(args, payload)
idx++
} }
if len(updateMap) == 0 { if len(setClauses) == 0 {
return 0, nil return 0, nil
} }
result := r.db.WithContext(ctx). setClauses = append(setClauses, "updated_at = NOW()")
Model(&accountModel{}).
Where("id IN ?", ids). query := "UPDATE accounts SET " + joinClauses(setClauses, ", ") + " WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
Clauses(clause.Returning{}). args = append(args, pq.Array(ids))
Updates(updateMap)
result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
return result.RowsAffected, result.Error type accountGroupQueryOptions struct {
status string
schedulable bool
platform string
} }
type accountModel struct { func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) {
ID int64 `gorm:"primaryKey"` q := r.client.AccountGroup.Query().
Name string `gorm:"size:100;not null"` Where(dbaccountgroup.GroupIDEQ(groupID))
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"` preds := make([]dbpredicate.Account, 0, 6)
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"` preds = append(preds, dbaccount.DeletedAtIsNil())
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"` if opts.status != "" {
ProxyID *int64 `gorm:"index"` preds = append(preds, dbaccount.StatusEQ(opts.status))
Concurrency int `gorm:"default:3;not null"` }
Priority int `gorm:"default:50;not null"` if opts.platform != "" {
Status string `gorm:"size:20;default:active;not null"` preds = append(preds, dbaccount.PlatformEQ(opts.platform))
ErrorMessage string `gorm:"type:text"` }
LastUsedAt *time.Time `gorm:"index"` if opts.schedulable {
CreatedAt time.Time `gorm:"not null"` now := time.Now()
UpdatedAt time.Time `gorm:"not null"` preds = append(preds,
DeletedAt gorm.DeletedAt `gorm:"index"` dbaccount.SchedulableEQ(true),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
}
Schedulable bool `gorm:"default:true;not null"` if len(preds) > 0 {
q = q.Where(dbaccountgroup.HasAccountWith(preds...))
}
RateLimitedAt *time.Time `gorm:"index"` groups, err := q.
RateLimitResetAt *time.Time `gorm:"index"` Order(
OverloadUntil *time.Time `gorm:"index"` dbaccountgroup.ByPriority(),
dbaccountgroup.ByAccountField(dbaccount.FieldPriority),
).
WithAccount().
All(ctx)
if err != nil {
return nil, err
}
SessionWindowStart *time.Time orderedIDs := make([]int64, 0, len(groups))
SessionWindowEnd *time.Time accountMap := make(map[int64]*dbent.Account, len(groups))
SessionWindowStatus string `gorm:"size:20"` for _, ag := range groups {
if ag.Edges.Account == nil {
continue
}
if _, exists := accountMap[ag.AccountID]; exists {
continue
}
accountMap[ag.AccountID] = ag.Edges.Account
orderedIDs = append(orderedIDs, ag.AccountID)
}
Proxy *proxyModel `gorm:"foreignKey:ProxyID"` accounts := make([]*dbent.Account, 0, len(orderedIDs))
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"` for _, id := range orderedIDs {
if acc, ok := accountMap[id]; ok {
accounts = append(accounts, acc)
}
}
return r.accountsToService(ctx, accounts)
} }
func (accountModel) TableName() string { return "accounts" } func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) {
if len(accounts) == 0 {
return []service.Account{}, nil
}
type accountGroupModel struct { accountIDs := make([]int64, 0, len(accounts))
AccountID int64 `gorm:"primaryKey"` proxyIDs := make([]int64, 0, len(accounts))
GroupID int64 `gorm:"primaryKey"` for _, acc := range accounts {
Priority int `gorm:"default:50;not null"` accountIDs = append(accountIDs, acc.ID)
CreatedAt time.Time `gorm:"not null"` if acc.ProxyID != nil {
proxyIDs = append(proxyIDs, *acc.ProxyID)
}
}
Account *accountModel `gorm:"foreignKey:AccountID"` proxyMap, err := r.loadProxies(ctx, proxyIDs)
Group *groupModel `gorm:"foreignKey:GroupID"` if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for _, acc := range accounts {
out := accountEntityToService(acc)
if out == nil {
continue
}
if acc.ProxyID != nil {
if proxy, ok := proxyMap[*acc.ProxyID]; ok {
out.Proxy = proxy
}
}
if groups, ok := groupsByAccount[acc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[acc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
outAccounts = append(outAccounts, *out)
}
return outAccounts, nil
} }
func (accountGroupModel) TableName() string { return "account_groups" } func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
return proxyMap, nil
}
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup { proxies, err := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...)).All(ctx)
if m == nil { if err != nil {
return nil return nil, err
} }
return &service.AccountGroup{
AccountID: m.AccountID, for _, p := range proxies {
GroupID: m.GroupID, proxyMap[p.ID] = proxyEntityToService(p)
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
} }
return proxyMap, nil
} }
func accountModelToService(m *accountModel) *service.Account { func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) {
if m == nil { groupsByAccount := make(map[int64][]*service.Group)
return nil groupIDsByAccount := make(map[int64][]int64)
accountGroupsByAccount := make(map[int64][]service.AccountGroup)
if len(accountIDs) == 0 {
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
} }
var credentials map[string]any entries, err := r.client.AccountGroup.Query().
if m.Credentials != nil { Where(dbaccountgroup.AccountIDIn(accountIDs...)).
credentials = map[string]any(m.Credentials) WithGroup().
Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()).
All(ctx)
if err != nil {
return nil, nil, nil, err
}
for _, ag := range entries {
groupSvc := groupEntityToService(ag.Edges.Group)
agSvc := service.AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Group: groupSvc,
}
accountGroupsByAccount[ag.AccountID] = append(accountGroupsByAccount[ag.AccountID], agSvc)
groupIDsByAccount[ag.AccountID] = append(groupIDsByAccount[ag.AccountID], ag.GroupID)
if groupSvc != nil {
groupsByAccount[ag.AccountID] = append(groupsByAccount[ag.AccountID], groupSvc)
}
} }
var extra map[string]any return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
if m.Extra != nil { }
extra = map[string]any(m.Extra)
func accountEntityToService(m *dbent.Account) *service.Account {
if m == nil {
return nil
} }
account := &service.Account{ return &service.Account{
ID: m.ID, ID: m.ID,
Name: m.Name, Name: m.Name,
Platform: m.Platform, Platform: m.Platform,
Type: m.Type, Type: m.Type,
Credentials: credentials, Credentials: copyJSONMap(m.Credentials),
Extra: extra, Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID, ProxyID: m.ProxyID,
Concurrency: m.Concurrency, Concurrency: m.Concurrency,
Priority: m.Priority, Priority: m.Priority,
Status: m.Status, Status: m.Status,
ErrorMessage: m.ErrorMessage, ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
...@@ -536,75 +798,39 @@ func accountModelToService(m *accountModel) *service.Account { ...@@ -536,75 +798,39 @@ func accountModelToService(m *accountModel) *service.Account {
OverloadUntil: m.OverloadUntil, OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart, SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd, SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus, SessionWindowStatus: derefString(m.SessionWindowStatus),
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
} }
}
return account func normalizeJSONMap(in map[string]any) map[string]any {
if in == nil {
return map[string]any{}
}
return in
} }
func accountModelFromService(a *service.Account) *accountModel { func copyJSONMap(in map[string]any) map[string]any {
if a == nil { if in == nil {
return nil return nil
} }
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func joinClauses(clauses []string, sep string) string {
if len(clauses) == 0 {
return ""
}
out := clauses[0]
for i := 1; i < len(clauses); i++ {
out += sep + clauses[i]
}
return out
}
var credentials datatypes.JSONMap func itoa(v int) string {
if a.Credentials != nil { return strconv.Itoa(v)
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
} }
...@@ -4,27 +4,31 @@ package repository ...@@ -4,27 +4,31 @@ package repository
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
) )
type AccountRepoSuite struct { type AccountRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB tx *sql.Tx
client *dbent.Client
repo *accountRepository repo *accountRepository
} }
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) client, tx := testEntSQLTx(s.T())
s.repo = NewAccountRepository(s.db).(*accountRepository) s.client = client
s.tx = tx
s.repo = newAccountRepositoryWithSQL(client, tx)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
...@@ -61,7 +65,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() { ...@@ -61,7 +65,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
} }
func (s *AccountRepoSuite) TestUpdate() { func (s *AccountRepoSuite) TestUpdate() {
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"})) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
account.Name = "updated" account.Name = "updated"
err := s.repo.Update(s.ctx, account) err := s.repo.Update(s.ctx, account)
...@@ -73,7 +77,7 @@ func (s *AccountRepoSuite) TestUpdate() { ...@@ -73,7 +77,7 @@ func (s *AccountRepoSuite) TestUpdate() {
} }
func (s *AccountRepoSuite) TestDelete() { func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -83,23 +87,23 @@ func (s *AccountRepoSuite) TestDelete() { ...@@ -83,23 +87,23 @@ func (s *AccountRepoSuite) TestDelete() {
} }
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings") s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64 count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count) s.Require().NoError(err)
s.Require().Zero(count, "expected bindings to be removed") s.Require().Zero(count, "expected bindings to be removed")
} }
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() { func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
...@@ -110,7 +114,7 @@ func (s *AccountRepoSuite) TestList() { ...@@ -110,7 +114,7 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() { func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct { tests := []struct {
name string name string
setup func(db *gorm.DB) setup func(client *dbent.Client)
platform string platform string
accType string accType string
status string status string
...@@ -120,9 +124,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -120,9 +124,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}{ }{
{ {
name: "filter_by_platform", name: "filter_by_platform",
setup: func(db *gorm.DB) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic}) mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI}) mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
}, },
platform: service.PlatformOpenAI, platform: service.PlatformOpenAI,
wantCount: 1, wantCount: 1,
...@@ -132,9 +136,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -132,9 +136,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}, },
{ {
name: "filter_by_type", name: "filter_by_type",
setup: func(db *gorm.DB) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth}) mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey}) mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
}, },
accType: service.AccountTypeApiKey, accType: service.AccountTypeApiKey,
wantCount: 1, wantCount: 1,
...@@ -144,9 +148,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -144,9 +148,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}, },
{ {
name: "filter_by_status", name: "filter_by_status",
setup: func(db *gorm.DB) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive}) mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled}) mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
}, },
status: service.StatusDisabled, status: service.StatusDisabled,
wantCount: 1, wantCount: 1,
...@@ -156,9 +160,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -156,9 +160,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}, },
{ {
name: "filter_by_search", name: "filter_by_search",
setup: func(db *gorm.DB) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"}) mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"}) mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
}, },
search: "alpha", search: "alpha",
wantCount: 1, wantCount: 1,
...@@ -171,11 +175,11 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -171,11 +175,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
for _, tt := range tests { for _, tt := range tests {
s.Run(tt.name, func() { s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
db := testTx(s.T()) client, tx := testEntSQLTx(s.T())
repo := NewAccountRepository(db).(*accountRepository) repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background() ctx := context.Background()
tt.setup(db) tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -190,11 +194,11 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -190,11 +194,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform --- // --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() { func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive}) acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive}) acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID) accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup") s.Require().NoError(err, "ListByGroup")
...@@ -204,8 +208,8 @@ func (s *AccountRepoSuite) TestListByGroup() { ...@@ -204,8 +208,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
} }
func (s *AccountRepoSuite) TestListActive() { func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx) accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
...@@ -214,8 +218,8 @@ func (s *AccountRepoSuite) TestListActive() { ...@@ -214,8 +218,8 @@ func (s *AccountRepoSuite) TestListActive() {
} }
func (s *AccountRepoSuite) TestListByPlatform() { func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic) accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform") s.Require().NoError(err, "ListByPlatform")
...@@ -226,14 +230,14 @@ func (s *AccountRepoSuite) TestListByPlatform() { ...@@ -226,14 +230,14 @@ func (s *AccountRepoSuite) TestListByPlatform() {
// --- Preload and VirtualFields --- // --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"}) proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &accountModel{ account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc1", Name: "acc1",
ProxyID: &proxy.ID, ProxyID: &proxy.ID,
}) })
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID") s.Require().NoError(err, "GetByID")
...@@ -257,9 +261,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { ...@@ -257,9 +261,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups --- // --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"}) g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"}) g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup") s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID) groups, err := s.repo.GetGroups(s.ctx, account.ID)
...@@ -279,9 +283,9 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { ...@@ -279,9 +283,9 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
} }
func (s *AccountRepoSuite) TestBindGroups_EmptyList() { func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty") s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
...@@ -294,14 +298,14 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() { ...@@ -294,14 +298,14 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() { func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx) sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable") s.Require().NoError(err, "ListSchedulable")
...@@ -312,17 +316,17 @@ func (s *AccountRepoSuite) TestListSchedulable() { ...@@ -312,17 +316,17 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() { func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true}) rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited") s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError") s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
...@@ -339,8 +343,8 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu ...@@ -339,8 +343,8 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
} }
func (s *AccountRepoSuite) TestListSchedulableByPlatform() { func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true}) mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -349,11 +353,11 @@ func (s *AccountRepoSuite) TestListSchedulableByPlatform() { ...@@ -349,11 +353,11 @@ func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
} }
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true}) a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true}) a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -362,7 +366,7 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { ...@@ -362,7 +366,7 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
} }
func (s *AccountRepoSuite) TestSetSchedulable() { func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
...@@ -374,7 +378,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() { ...@@ -374,7 +378,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit --- // --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() { func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
...@@ -386,7 +390,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() { ...@@ -386,7 +390,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
} }
func (s *AccountRepoSuite) TestSetRateLimited() { func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC) resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
...@@ -399,7 +403,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() { ...@@ -399,7 +403,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
} }
func (s *AccountRepoSuite) TestClearRateLimit() { func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour) until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
...@@ -416,7 +420,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() { ...@@ -416,7 +420,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed --- // --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() { func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt) s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID)) s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
...@@ -429,7 +433,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() { ...@@ -429,7 +433,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError --- // --- SetError ---
func (s *AccountRepoSuite) TestSetError() { func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
...@@ -442,7 +446,7 @@ func (s *AccountRepoSuite) TestSetError() { ...@@ -442,7 +446,7 @@ func (s *AccountRepoSuite) TestSetError() {
// --- UpdateSessionWindow --- // --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() { func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC) end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
...@@ -458,9 +462,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() { ...@@ -458,9 +462,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra --- // --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &accountModel{ account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra", Name: "acc-extra",
Extra: datatypes.JSONMap{"a": "1"}, Extra: map[string]any{"a": "1"},
}) })
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra") s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
...@@ -471,12 +475,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { ...@@ -471,12 +475,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
} }
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() { func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
} }
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
...@@ -488,9 +492,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { ...@@ -488,9 +492,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() { func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345" crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &accountModel{ mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-crs", Name: "acc-crs",
Extra: datatypes.JSONMap{"crs_account_id": crsID}, Extra: map[string]any{"crs_account_id": crsID},
}) })
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID) got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
...@@ -514,8 +518,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() { ...@@ -514,8 +518,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate --- // --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() { func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1}) a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1}) a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
newPriority := 99 newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{ affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
...@@ -531,13 +535,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() { ...@@ -531,13 +535,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{ a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-cred", Name: "bulk-cred",
Credentials: datatypes.JSONMap{"existing": "value"}, Credentials: map[string]any{"existing": "value"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: datatypes.JSONMap{"new_key": "new_value"}, Credentials: map[string]any{"new_key": "new_value"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
...@@ -547,13 +551,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { ...@@ -547,13 +551,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() { func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{ a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-extra", Name: "bulk-extra",
Extra: datatypes.JSONMap{"existing": "val"}, Extra: map[string]any{"existing": "val"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: datatypes.JSONMap{"new_key": "new_val"}, Extra: map[string]any{"new_key": "new_val"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
...@@ -569,7 +573,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() { ...@@ -569,7 +573,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"}) a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err) s.Require().NoError(err)
......
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueTestValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "other-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, sqlTx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u1))
u2 := &service.User{
Email: uniqueTestValue(t, "u2") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID},
}
require.NoError(t, repo.Create(ctx, u2))
u3 := &service.User{
Email: uniqueTestValue(t, "u3") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u3))
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
require.NoError(t, err)
require.Equal(t, int64(2), affected)
u1After, err := repo.GetByID(ctx, u1.ID)
require.NoError(t, err)
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
u2After, err := repo.GetByID(ctx, u2.ID)
require.NoError(t, err)
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-other")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",
GroupID: &targetGroup.ID,
Status: service.StatusActive,
}
require.NoError(t, apiKeyRepo.Create(ctx, key))
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
require.NoError(t, err)
// Deleted group should be hidden by default queries (soft-delete semantics).
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
require.ErrorIs(t, err, service.ErrGroupNotFound)
activeGroups, err := groupRepo.ListActive(ctx)
require.NoError(t, err)
for _, g := range activeGroups {
require.NotEqual(t, targetGroup.ID, g.ID)
}
// User.allowed_groups should no longer include the deleted group.
uAfter, err := userRepo.GetByID(ctx, u.ID)
require.NoError(t, err)
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}
...@@ -2,83 +2,118 @@ package repository ...@@ -2,83 +2,118 @@ package repository
import ( import (
"context" "context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
type apiKeyRepository struct { type apiKeyRepository struct {
db *gorm.DB client *dbent.Client
} }
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository { func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{db: db} return &apiKeyRepository{client: client}
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key) created, err := r.client.ApiKey.Create().
err := r.db.WithContext(ctx).Create(m).Error SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
if err == nil { if err == nil {
applyApiKeyModelToService(key, m) key.ID = created.ID
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
} }
return translatePersistenceError(err, nil, service.ErrApiKeyExists) return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel m, err := r.client.ApiKey.Query().
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error Where(apikey.IDEQ(id)).
WithUser().
WithGroup().
Only(ctx)
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return nil, err
} }
return apiKeyModelToService(&m), nil return apiKeyEntityToService(m), nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel m, err := r.client.ApiKey.Query().
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error Where(apikey.KeyEQ(key)).
WithUser().
WithGroup().
Only(ctx)
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return nil, err
} }
return apiKeyModelToService(&m), nil return apiKeyEntityToService(m), nil
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key) builder := r.client.ApiKey.UpdateOneID(key.ID).
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error SetName(key.Name).
SetStatus(key.Status)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
} else {
builder.ClearGroupID()
}
updated, err := builder.Save(ctx)
if err == nil { if err == nil {
applyApiKeyModelToService(key, m) key.UpdatedAt = updated.UpdatedAt
return nil
}
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
} }
return err return err
} }
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error _, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
return err
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
var total int64
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID) total, err := q.Count(ctx)
if err != nil {
if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil { keys, err := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err return nil, nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
return outKeys, paginationResultFromTotal(total, params), nil return outKeys, paginationResultFromTotal(int64(total), params), nil
} }
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
...@@ -86,11 +121,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap ...@@ -86,11 +121,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil return []int64{}, nil
} }
ids := make([]int64, 0, len(apiKeyIDs)) ids, err := r.client.ApiKey.Query().
err := r.db.WithContext(ctx). Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
Model(&apiKeyModel{}). IDs(ctx)
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
Pluck("id", &ids).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -98,136 +131,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap ...@@ -98,136 +131,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
} }
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error return int64(count), err
return count, err
} }
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
var total int64
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil { total, err := q.Count(ctx)
if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil { keys, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err return nil, nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
return outKeys, paginationResultFromTotal(total, params), nil return outKeys, paginationResultFromTotal(int64(total), params), nil
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel q := r.client.ApiKey.Query()
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 { if userID > 0 {
db = db.Where("user_id = ?", userID) q = q.Where(apikey.UserIDEQ(userID))
} }
if keyword != "" { if keyword != "" {
searchPattern := "%" + keyword + "%" q = q.Where(apikey.NameContainsFold(keyword))
db = db.Where("name ILIKE ?", searchPattern)
} }
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil { keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
if err != nil {
return nil, err return nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
return outKeys, nil return outKeys, nil
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&apiKeyModel{}). n, err := r.client.ApiKey.Update().
Where("group_id = ?", groupID). Where(apikey.GroupIDEQ(groupID)).
Update("group_id", nil) ClearGroupID().
return result.RowsAffected, result.Error Save(ctx)
return int64(n), err
} }
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error return int64(count), err
return count, err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
} }
func (apiKeyModel) TableName() string { return "api_keys" } func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil { if m == nil {
return nil return nil
} }
return &service.ApiKey{ out := &service.ApiKey{
ID: m.ID, ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
Key: m.Key, Key: m.Key,
Name: m.Name, Name: m.Name,
GroupID: m.GroupID,
Status: m.Status, Status: m.Status,
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User), GroupID: m.GroupID,
Group: groupModelToService(m.Group), }
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
} }
} }
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel { func groupEntityToService(g *dbent.Group) *service.Group {
if k == nil { if g == nil {
return nil return nil
} }
return &apiKeyModel{ return &service.Group{
ID: k.ID, ID: g.ID,
UserID: k.UserID, Name: g.Name,
Key: k.Key, Description: derefString(g.Description),
Name: k.Name, Platform: g.Platform,
GroupID: k.GroupID, RateMultiplier: g.RateMultiplier,
Status: k.Status, IsExclusive: g.IsExclusive,
CreatedAt: k.CreatedAt, Status: g.Status,
UpdatedAt: k.UpdatedAt, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
} }
} }
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) { func derefString(s *string) string {
if key == nil || m == nil { if s == nil {
return return ""
} }
key.ID = m.ID return *s
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
} }
...@@ -6,23 +6,24 @@ import ( ...@@ -6,23 +6,24 @@ import (
"context" "context"
"testing" "testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm"
) )
type ApiKeyRepoSuite struct { type ApiKeyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB client *dbent.Client
repo *apiKeyRepository repo *apiKeyRepository
} }
func (s *ApiKeyRepoSuite) SetupTest() { func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) entClient, _ := testEntSQLTx(s.T())
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository) s.client = entClient
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestApiKeyRepoSuite(t *testing.T) {
...@@ -32,7 +33,7 @@ func TestApiKeyRepoSuite(t *testing.T) { ...@@ -32,7 +33,7 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey --- // --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() { func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"}) user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
...@@ -56,16 +57,17 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { ...@@ -56,16 +57,17 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ApiKeyRepoSuite) TestGetByKey() { func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"}) user := s.mustCreateUser("getbykey@test.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"}) group := s.mustCreateGroup("g-key")
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-getbykey", Key: "sk-getbykey",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: service.StatusActive, Status: service.StatusActive,
}) }
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey") s.Require().NoError(err, "GetByKey")
...@@ -84,13 +86,14 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { ...@@ -84,13 +86,14 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update --- // --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() { func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"}) user := s.mustCreateUser("update@test.com")
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-update", Key: "sk-update",
Name: "Original", Name: "Original",
Status: service.StatusActive, Status: service.StatusActive,
})) }
s.Require().NoError(s.repo.Create(s.ctx, key))
key.Name = "Renamed" key.Name = "Renamed"
key.Status = service.StatusDisabled key.Status = service.StatusDisabled
...@@ -106,14 +109,16 @@ func (s *ApiKeyRepoSuite) TestUpdate() { ...@@ -106,14 +109,16 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
} }
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"}) user := s.mustCreateUser("cleargroup@test.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"}) group := s.mustCreateGroup("g-clear")
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-clear-group", Key: "sk-clear-group",
Name: "Group Key", Name: "Group Key",
GroupID: &group.ID, GroupID: &group.ID,
})) Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.GroupID = nil key.GroupID = nil
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
...@@ -127,12 +132,14 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { ...@@ -127,12 +132,14 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete --- // --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() { func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"}) user := s.mustCreateUser("delete@test.com")
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-delete", Key: "sk-delete",
Name: "Delete Me", Name: "Delete Me",
}) Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
err := s.repo.Delete(s.ctx, key.ID) err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
...@@ -144,9 +151,9 @@ func (s *ApiKeyRepoSuite) TestDelete() { ...@@ -144,9 +151,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID --- // --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() { func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"}) user := s.mustCreateUser("listbyuser@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID") s.Require().NoError(err, "ListByUserID")
...@@ -155,13 +162,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() { ...@@ -155,13 +162,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
} }
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"}) user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &apiKeyModel{ s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
} }
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
...@@ -172,9 +175,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { ...@@ -172,9 +175,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
} }
func (s *ApiKeyRepoSuite) TestCountByUserID() { func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"}) user := s.mustCreateUser("count@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID) count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID") s.Require().NoError(err, "CountByUserID")
...@@ -184,12 +187,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() { ...@@ -184,12 +187,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID --- // --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() { func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"}) user := s.mustCreateUser("listbygroup@test.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"}) group := s.mustCreateGroup("g-list")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID") s.Require().NoError(err, "ListByGroupID")
...@@ -200,10 +203,9 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() { ...@@ -200,10 +203,9 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
} }
func (s *ApiKeyRepoSuite) TestCountByGroupID() { func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"}) user := s.mustCreateUser("countgroup@test.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"}) group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID) count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID") s.Require().NoError(err, "CountByGroupID")
...@@ -213,8 +215,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() { ...@@ -213,8 +215,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey --- // --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() { func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"}) user := s.mustCreateUser("exists@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"}) s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey") s.Require().NoError(err, "ExistsByKey")
...@@ -228,9 +230,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() { ...@@ -228,9 +230,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys --- // --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() { func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"}) user := s.mustCreateUser("search@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchApiKeys")
...@@ -239,9 +241,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() { ...@@ -239,9 +241,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"}) user := s.mustCreateUser("searchnokw@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -249,8 +251,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { ...@@ -249,8 +251,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"}) user := s.mustCreateUser("searchnouid@test.com")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err) s.Require().NoError(err)
...@@ -260,12 +262,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { ...@@ -260,12 +262,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID --- // --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"}) user := s.mustCreateUser("cleargrp@test.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"}) group := s.mustCreateGroup("g-clear-bulk")
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID") s.Require().NoError(err, "ClearGroupIDByGroupID")
...@@ -283,16 +285,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { ...@@ -283,16 +285,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"}) user := s.mustCreateUser("k@example.com")
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"}) group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ key.GroupID = &group.ID
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}))
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey") s.Require().NoError(err, "GetByKey")
...@@ -330,12 +326,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -330,12 +326,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID) s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID // ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
UserID: user.ID, k2.GroupID = &group.ID
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID) countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID") s.Require().NoError(err, "CountByGroupID")
...@@ -353,3 +345,41 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -353,3 +345,41 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "CountByGroupID after clear") s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear") s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
} }
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
s.T().Helper()
k := &service.ApiKey{
UserID: userID,
Key: key,
Name: name,
GroupID: groupID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
package repository
import (
"log"
"time"
"gorm.io/gorm"
)
// MaxExpiresAt is the maximum allowed expiration date for subscriptions (year 2099)
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
err := db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
if err != nil {
return err
}
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
return fixInvalidExpiresAt(db)
}
// fixInvalidExpiresAt 修复 user_subscriptions 表中无效的过期时间
func fixInvalidExpiresAt(db *gorm.DB) error {
result := db.Model(&userSubscriptionModel{}).
Where("expires_at > ?", maxExpiresAt).
Update("expires_at", maxExpiresAt)
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
log.Printf("[AutoMigrate] Fixed %d subscriptions with invalid expires_at (year > 2099)", result.RowsAffected)
}
return nil
}
package repository package repository
import ( import (
"database/sql"
"errors" "errors"
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm" "github.com/lib/pq"
) )
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound)
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows)。
//
// 参数:
// - err: 原始数据库错误
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
//
// 返回:
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
//
// 示例:
//
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error { func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil { if err == nil {
return nil return nil
} }
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) { // 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
// Ent 使用自定义的 NotFoundError,而标准库使用 sql.ErrNoRows。
// 这里同时处理两种情况,保持业务错误映射一致。
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
return notFound.WithCause(err) return notFound.WithCause(err)
} }
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
if conflict != nil && isUniqueConstraintViolation(err) { if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err) return conflict.WithCause(err)
} }
// 未匹配任何规则,返回原始错误
return err return err
} }
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
//
// 支持多种检测方式:
// 1. PostgreSQL 特定错误码 23505(唯一约束冲突)
// 2. 错误消息中包含的通用关键词
//
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
func isUniqueConstraintViolation(err error) bool { func isUniqueConstraintViolation(err error) bool {
if err == nil { if err == nil {
return false return false
} }
if errors.Is(err, gorm.ErrDuplicatedKey) { // 优先检测 PostgreSQL 特定错误码(最精确)。
return true // 错误码 23505 对应 unique_violation。
// 参考:https://www.postgresql.org/docs/current/errcodes-appendix.html
var pgErr *pq.Error
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
} }
// 回退到错误消息检测(兼容其他场景)。
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
msg := strings.ToLower(err.Error()) msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") || return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") || strings.Contains(msg, "unique constraint") ||
......
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
package repository package repository
import ( import (
"context"
"testing" "testing"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
) )
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel { func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
t.Helper() t.Helper()
ctx := context.Background()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" { if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash" u.PasswordHash = "test-password-hash"
} }
...@@ -26,18 +31,48 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel { ...@@ -26,18 +31,48 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
if u.Concurrency == 0 { if u.Concurrency == 0 {
u.Concurrency = 5 u.Concurrency = 5
} }
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now() create := client.User.Create().
SetEmail(u.Email).
SetPasswordHash(u.PasswordHash).
SetRole(u.Role).
SetStatus(u.Status).
SetBalance(u.Balance).
SetConcurrency(u.Concurrency).
SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes)
if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt)
}
if !u.UpdatedAt.IsZero() {
create.SetUpdatedAt(u.UpdatedAt)
} }
if u.UpdatedAt.IsZero() {
u.UpdatedAt = u.CreatedAt created, err := create.Save(ctx)
require.NoError(t, err, "create user")
u.ID = created.ID
u.CreatedAt = created.CreatedAt
u.UpdatedAt = created.UpdatedAt
if len(u.AllowedGroups) > 0 {
for _, groupID := range u.AllowedGroups {
_, err := client.UserAllowedGroup.Create().
SetUserID(u.ID).
SetGroupID(groupID).
Save(ctx)
require.NoError(t, err, "create user_allowed_groups row")
}
} }
require.NoError(t, db.Create(u).Error, "create user")
return u return u
} }
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel { func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
t.Helper() t.Helper()
ctx := context.Background()
if g.Platform == "" { if g.Platform == "" {
g.Platform = service.PlatformAnthropic g.Platform = service.PlatformAnthropic
} }
...@@ -47,18 +82,46 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel { ...@@ -47,18 +82,46 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
if g.SubscriptionType == "" { if g.SubscriptionType == "" {
g.SubscriptionType = service.SubscriptionTypeStandard g.SubscriptionType = service.SubscriptionTypeStandard
} }
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now() create := client.Group.Create().
SetName(g.Name).
SetPlatform(g.Platform).
SetStatus(g.Status).
SetSubscriptionType(g.SubscriptionType).
SetRateMultiplier(g.RateMultiplier).
SetIsExclusive(g.IsExclusive)
if g.Description != "" {
create.SetDescription(g.Description)
}
if g.DailyLimitUSD != nil {
create.SetDailyLimitUsd(*g.DailyLimitUSD)
}
if g.WeeklyLimitUSD != nil {
create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
}
if g.MonthlyLimitUSD != nil {
create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
} }
if g.UpdatedAt.IsZero() { if !g.CreatedAt.IsZero() {
g.UpdatedAt = g.CreatedAt create.SetCreatedAt(g.CreatedAt)
} }
require.NoError(t, db.Create(g).Error, "create group") if !g.UpdatedAt.IsZero() {
create.SetUpdatedAt(g.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create group")
g.ID = created.ID
g.CreatedAt = created.CreatedAt
g.UpdatedAt = created.UpdatedAt
return g return g
} }
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel { func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
t.Helper() t.Helper()
ctx := context.Background()
if p.Protocol == "" { if p.Protocol == "" {
p.Protocol = "http" p.Protocol = "http"
} }
...@@ -71,18 +134,39 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel { ...@@ -71,18 +134,39 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
if p.Status == "" { if p.Status == "" {
p.Status = service.StatusActive p.Status = service.StatusActive
} }
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now() create := client.Proxy.Create().
SetName(p.Name).
SetProtocol(p.Protocol).
SetHost(p.Host).
SetPort(p.Port).
SetStatus(p.Status)
if p.Username != "" {
create.SetUsername(p.Username)
}
if p.Password != "" {
create.SetPassword(p.Password)
} }
if p.UpdatedAt.IsZero() { if !p.CreatedAt.IsZero() {
p.UpdatedAt = p.CreatedAt create.SetCreatedAt(p.CreatedAt)
} }
require.NoError(t, db.Create(p).Error, "create proxy") if !p.UpdatedAt.IsZero() {
create.SetUpdatedAt(p.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create proxy")
p.ID = created.ID
p.CreatedAt = created.CreatedAt
p.UpdatedAt = created.UpdatedAt
return p return p
} }
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel { func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
t.Helper() t.Helper()
ctx := context.Background()
if a.Platform == "" { if a.Platform == "" {
a.Platform = service.PlatformAnthropic a.Platform = service.PlatformAnthropic
} }
...@@ -92,57 +176,158 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel ...@@ -92,57 +176,158 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel
if a.Status == "" { if a.Status == "" {
a.Status = service.StatusActive a.Status = service.StatusActive
} }
if a.Concurrency == 0 {
a.Concurrency = 3
}
if a.Priority == 0 {
a.Priority = 50
}
if !a.Schedulable { if !a.Schedulable {
a.Schedulable = true a.Schedulable = true
} }
if a.Credentials == nil { if a.Credentials == nil {
a.Credentials = datatypes.JSONMap{} a.Credentials = map[string]any{}
} }
if a.Extra == nil { if a.Extra == nil {
a.Extra = datatypes.JSONMap{} a.Extra = map[string]any{}
}
create := client.Account.Create().
SetName(a.Name).
SetPlatform(a.Platform).
SetType(a.Type).
SetCredentials(a.Credentials).
SetExtra(a.Extra).
SetConcurrency(a.Concurrency).
SetPriority(a.Priority).
SetStatus(a.Status).
SetSchedulable(a.Schedulable).
SetErrorMessage(a.ErrorMessage)
if a.ProxyID != nil {
create.SetProxyID(*a.ProxyID)
}
if a.LastUsedAt != nil {
create.SetLastUsedAt(*a.LastUsedAt)
}
if a.RateLimitedAt != nil {
create.SetRateLimitedAt(*a.RateLimitedAt)
}
if a.RateLimitResetAt != nil {
create.SetRateLimitResetAt(*a.RateLimitResetAt)
}
if a.OverloadUntil != nil {
create.SetOverloadUntil(*a.OverloadUntil)
} }
if a.CreatedAt.IsZero() { if a.SessionWindowStart != nil {
a.CreatedAt = time.Now() create.SetSessionWindowStart(*a.SessionWindowStart)
} }
if a.UpdatedAt.IsZero() { if a.SessionWindowEnd != nil {
a.UpdatedAt = a.CreatedAt create.SetSessionWindowEnd(*a.SessionWindowEnd)
} }
require.NoError(t, db.Create(a).Error, "create account") if a.SessionWindowStatus != "" {
create.SetSessionWindowStatus(a.SessionWindowStatus)
}
if !a.CreatedAt.IsZero() {
create.SetCreatedAt(a.CreatedAt)
}
if !a.UpdatedAt.IsZero() {
create.SetUpdatedAt(a.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create account")
a.ID = created.ID
a.CreatedAt = created.CreatedAt
a.UpdatedAt = created.UpdatedAt
return a return a
} }
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel { func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
t.Helper() t.Helper()
ctx := context.Background()
if k.Status == "" { if k.Status == "" {
k.Status = service.StatusActive k.Status = service.StatusActive
} }
if k.CreatedAt.IsZero() { if k.Key == "" {
k.CreatedAt = time.Now() k.Key = "sk-" + time.Now().Format("150405.000000")
} }
if k.UpdatedAt.IsZero() { if k.Name == "" {
k.UpdatedAt = k.CreatedAt k.Name = "default"
}
create := client.ApiKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
} }
require.NoError(t, db.Create(k).Error, "create api key") if !k.CreatedAt.IsZero() {
create.SetCreatedAt(k.CreatedAt)
}
if !k.UpdatedAt.IsZero() {
create.SetUpdatedAt(k.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create api key")
k.ID = created.ID
k.CreatedAt = created.CreatedAt
k.UpdatedAt = created.UpdatedAt
return k return k
} }
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel { func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
t.Helper() t.Helper()
ctx := context.Background()
if c.Status == "" { if c.Status == "" {
c.Status = service.StatusUnused c.Status = service.StatusUnused
} }
if c.Type == "" { if c.Type == "" {
c.Type = service.RedeemTypeBalance c.Type = service.RedeemTypeBalance
} }
if c.CreatedAt.IsZero() { if c.Code == "" {
c.CreatedAt = time.Now() c.Code = "rc-" + time.Now().Format("150405.000000")
}
create := client.RedeemCode.Create().
SetCode(c.Code).
SetType(c.Type).
SetValue(c.Value).
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays)
if c.UsedBy != nil {
create.SetUsedBy(*c.UsedBy)
}
if c.UsedAt != nil {
create.SetUsedAt(*c.UsedAt)
}
if c.GroupID != nil {
create.SetGroupID(*c.GroupID)
} }
require.NoError(t, db.Create(c).Error, "create redeem code") if !c.CreatedAt.IsZero() {
create.SetCreatedAt(c.CreatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create redeem code")
c.ID = created.ID
c.CreatedAt = created.CreatedAt
return c return c
} }
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel { func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
t.Helper() t.Helper()
ctx := context.Background()
if s.Status == "" { if s.Status == "" {
s.Status = service.SubscriptionStatusActive s.Status = service.SubscriptionStatusActive
} }
...@@ -162,16 +347,47 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) ...@@ -162,16 +347,47 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel)
if s.UpdatedAt.IsZero() { if s.UpdatedAt.IsZero() {
s.UpdatedAt = now s.UpdatedAt = now
} }
require.NoError(t, db.Create(s).Error, "create user subscription")
create := client.UserSubscription.Create().
SetUserID(s.UserID).
SetGroupID(s.GroupID).
SetStartsAt(s.StartsAt).
SetExpiresAt(s.ExpiresAt).
SetStatus(s.Status).
SetAssignedAt(s.AssignedAt).
SetNotes(s.Notes).
SetDailyUsageUsd(s.DailyUsageUSD).
SetWeeklyUsageUsd(s.WeeklyUsageUSD).
SetMonthlyUsageUsd(s.MonthlyUsageUSD)
if s.AssignedBy != nil {
create.SetAssignedBy(*s.AssignedBy)
}
if !s.CreatedAt.IsZero() {
create.SetCreatedAt(s.CreatedAt)
}
if !s.UpdatedAt.IsZero() {
create.SetUpdatedAt(s.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create user subscription")
s.ID = created.ID
s.CreatedAt = created.CreatedAt
s.UpdatedAt = created.UpdatedAt
return s return s
} }
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
t.Helper() t.Helper()
require.NoError(t, db.Create(&accountGroupModel{ ctx := context.Background()
AccountID: accountID,
GroupID: groupID, _, err := client.AccountGroup.Create().
Priority: priority, SetAccountID(accountID).
CreatedAt: time.Now(), SetGroupID(groupID).
}).Error, "create account_group") SetPriority(priority).
Save(ctx)
require.NoError(t, err, "create account_group")
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment