Commit f6f072cb authored by Edric Li's avatar Edric Li
Browse files

Merge branch 'main' into feat/api-key-ip-restriction

parents 5265b12c ff087586
...@@ -871,6 +871,29 @@ func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.Use ...@@ -871,6 +871,29 @@ func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.Use
}) })
} }
// HasPromoCodeUsages applies the HasEdge predicate on the "promo_code_usages" edge.
func HasPromoCodeUsages() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPromoCodeUsagesWith applies the HasEdge predicate on the "promo_code_usages" edge with a given conditions (other predicates).
func HasPromoCodeUsagesWith(preds ...predicate.PromoCodeUsage) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPromoCodeUsagesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User { func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
...@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate { ...@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
return _c.AddAttributeValueIDs(ids...) return _c.AddAttributeValueIDs(ids...)
} }
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_c *UserCreate) AddPromoCodeUsageIDs(ids ...int64) *UserCreate {
_c.mutation.AddPromoCodeUsageIDs(ids...)
return _c
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation { func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation return _c.mutation
...@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { ...@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
} }
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec return _node, _spec
} }
......
...@@ -9,12 +9,14 @@ import ( ...@@ -9,12 +9,14 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
...@@ -37,7 +39,9 @@ type UserQuery struct { ...@@ -37,7 +39,9 @@ type UserQuery struct {
withAllowedGroups *GroupQuery withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withUserAllowedGroups *UserAllowedGroupQuery withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery { ...@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
return query return query
} }
// QueryPromoCodeUsages chains the current query on the "promo_code_usages" edge.
func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query() query := (&UserAllowedGroupClient{config: _q.config}).Query()
...@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery { ...@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAllowedGroups: _q.withAllowedGroups.Clone(), withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(), withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(), withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query. // clone intermediate query.
sql: _q.sql.Clone(), sql: _q.sql.Clone(),
...@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery)) ...@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery))
return _q return _q
} }
// WithPromoCodeUsages tells the query-builder to eager-load the nodes that are connected to
// the "promo_code_usages" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *UserQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPromoCodeUsages = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
...@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var ( var (
nodes = []*User{} nodes = []*User{}
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [8]bool{ loadedTypes = [9]bool{
_q.withAPIKeys != nil, _q.withAPIKeys != nil,
_q.withRedeemCodes != nil, _q.withRedeemCodes != nil,
_q.withSubscriptions != nil, _q.withSubscriptions != nil,
...@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAllowedGroups != nil, _q.withAllowedGroups != nil,
_q.withUsageLogs != nil, _q.withUsageLogs != nil,
_q.withAttributeValues != nil, _q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withUserAllowedGroups != nil, _q.withUserAllowedGroups != nil,
} }
) )
...@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err return nil, err
} }
} }
if query := _q.withPromoCodeUsages; query != nil {
if err := _q.loadPromoCodeUsages(ctx, query, nodes,
func(n *User) { n.Edges.PromoCodeUsages = []*PromoCodeUsage{} },
func(n *User, e *PromoCodeUsage) { n.Edges.PromoCodeUsages = append(n.Edges.PromoCodeUsages, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil { if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes, if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
...@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu ...@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu
} }
return nil return nil
} }
func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUsageQuery, nodes []*User, init func(*User), assign func(*User, *PromoCodeUsage)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(promocodeusage.FieldUserID)
}
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PromoCodeUsagesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes)) fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User) nodeids := make(map[int64]*User)
...@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow ...@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow
func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserQuery) ForUpdate(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserGroupBy is the group-by builder for User entities. // UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct { type UserGroupBy struct {
selector selector
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
...@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate { ...@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
return _u.AddAttributeValueIDs(ids...) return _u.AddAttributeValueIDs(ids...)
} }
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdate) AddPromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation { func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat ...@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.RemoveAttributeValueIDs(ids...) return _u.RemoveAttributeValueIDs(ids...)
} }
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) ClearPromoCodeUsages() *UserUpdate {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdate) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation. // Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) { func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil { if err := _u.defaults(); err != nil {
...@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label} err = &NotFoundError{user.Label}
...@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat ...@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.AddAttributeValueIDs(ids...) return _u.AddAttributeValueIDs(ids...)
} }
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdateOne) AddPromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation { func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp ...@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp
return _u.RemoveAttributeValueIDs(ids...) return _u.RemoveAttributeValueIDs(ids...)
} }
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) ClearPromoCodeUsages() *UserUpdateOne {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdateOne) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder. // Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...) _u.mutation.Where(ps...)
...@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { ...@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config} _node = &User{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
...@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct { ...@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct {
predicates []predicate.UserAllowedGroup predicates []predicate.UserAllowedGroup
withUser *UserQuery withUser *UserQuery
withGroup *GroupQuery withGroup *GroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ...@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer ...@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer
func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Unique = false _spec.Unique = false
_spec.Node.Columns = nil _spec.Node.Columns = nil
return sqlgraph.CountNodes(ctx, _q.driver, _spec) return sqlgraph.CountNodes(ctx, _q.driver, _spec)
...@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAllowedGroupQuery) ForUpdate(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAllowedGroupQuery) ForShare(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities. // UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities.
type UserAllowedGroupGroupBy struct { type UserAllowedGroupGroupBy struct {
selector selector
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
...@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct { ...@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct {
inters []Interceptor inters []Interceptor
predicates []predicate.UserAttributeDefinition predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery withValues *UserAttributeValueQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que ...@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U ...@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec ...@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec ...@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeDefinitionQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeDefinitionQuery) ForShare(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities. // UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct { type UserAttributeDefinitionGroupBy struct {
selector selector
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
...@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct { ...@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates []predicate.UserAttributeValue predicates []predicate.UserAttributeValue
withUser *UserQuery withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery withDefinition *UserAttributeDefinitionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo ...@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us ...@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities. // UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct { type UserAttributeValueGroupBy struct {
selector selector
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
...@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct { ...@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup *GroupQuery withGroup *GroupQuery
withAssignedByUser *UserQuery withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ...@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage ...@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities. // UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct { type UserSubscriptionGroupBy struct {
selector selector
......
...@@ -687,7 +687,7 @@ func setDefaults() { ...@@ -687,7 +687,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type PromoHandler struct {
promoService *service.PromoService
}
// NewPromoHandler creates a new admin promo handler
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
return &PromoHandler{
promoService: promoService,
}
}
// CreatePromoCodeRequest represents create promo code request
type CreatePromoCodeRequest struct {
Code string `json:"code"` // 可选,为空则自动生成
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
Notes string `json:"notes"` // 备注
}
// UpdatePromoCodeRequest represents update promo code request
type UpdatePromoCodeRequest struct {
Code *string `json:"code"`
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt *int64 `json:"expires_at"`
Notes *string `json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func (h *PromoHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.PromoCodeFromService(&codes[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func (h *PromoHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func (h *PromoHandler) Create(c *gin.Context) {
var req CreatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.CreatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
code, err := h.promoService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Update(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
var req UpdatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.UpdatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Status: req.Status,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
if *req.ExpiresAt == 0 {
// 0 表示清除过期时间
input.ExpiresAt = nil
} else {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
}
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
err = h.promoService.Delete(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func (h *PromoHandler) GetUsages(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCodeUsage, 0, len(usages))
for i := range usages {
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
...@@ -12,19 +12,21 @@ import ( ...@@ -12,19 +12,21 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService settingSvc *service.SettingService
promoService *service.PromoService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService, settingSvc: settingService,
promoService: promoService,
} }
} }
...@@ -34,6 +36,7 @@ type RegisterRequest struct { ...@@ -34,6 +36,7 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"` VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"` TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
} }
// SendVerifyCodeRequest 发送验证码请求 // SendVerifyCodeRequest 发送验证码请求
...@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) { ...@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { ...@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
} }
// ValidatePromoCodeRequest 验证优惠码请求
type ValidatePromoCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type ValidatePromoCodeResponse struct {
Valid bool `json:"valid"`
BonusAmount float64 `json:"bonus_amount,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
if err != nil {
// 根据错误类型返回对应的错误码
errorCode := "PROMO_CODE_INVALID"
switch err {
case service.ErrPromoCodeNotFound:
errorCode = "PROMO_CODE_NOT_FOUND"
case service.ErrPromoCodeExpired:
errorCode = "PROMO_CODE_EXPIRED"
case service.ErrPromoCodeDisabled:
errorCode = "PROMO_CODE_DISABLED"
case service.ErrPromoCodeMaxUsed:
errorCode = "PROMO_CODE_MAX_USED"
case service.ErrPromoCodeAlreadyUsed:
errorCode = "PROMO_CODE_ALREADY_USED"
}
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: errorCode,
})
return
}
if promoCode == nil {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_INVALID",
})
return
}
response.Success(c, ValidatePromoCodeResponse{
Valid: true,
BonusAmount: promoCode.BonusAmount,
})
}
...@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult ...@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors: r.Errors, Errors: r.Errors,
} }
} }
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
if pc == nil {
return nil
}
return &PromoCode{
ID: pc.ID,
Code: pc.Code,
BonusAmount: pc.BonusAmount,
MaxUses: pc.MaxUses,
UsedCount: pc.UsedCount,
Status: pc.Status,
ExpiresAt: pc.ExpiresAt,
Notes: pc.Notes,
CreatedAt: pc.CreatedAt,
UpdatedAt: pc.UpdatedAt,
}
}
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
if u == nil {
return nil
}
return &PromoCodeUsage{
ID: u.ID,
PromoCodeID: u.PromoCodeID,
UserID: u.UserID,
BonusAmount: u.BonusAmount,
UsedAt: u.UsedAt,
User: UserFromServiceShallow(u.User),
}
}
...@@ -250,3 +250,28 @@ type BulkAssignResult struct { ...@@ -250,3 +250,28 @@ type BulkAssignResult struct {
Subscriptions []UserSubscription `json:"subscriptions"` Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"` Errors []string `json:"errors"`
} }
// PromoCode 注册优惠码
type PromoCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
BonusAmount float64 `json:"bonus_amount"`
MaxUses int `json:"max_uses"`
UsedCount int `json:"used_count"`
Status string `json:"status"`
ExpiresAt *time.Time `json:"expires_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type PromoCodeUsage struct {
ID int64 `json:"id"`
PromoCodeID int64 `json:"promo_code_id"`
UserID int64 `json:"user_id"`
BonusAmount float64 `json:"bonus_amount"`
UsedAt time.Time `json:"used_at"`
User *User `json:"user,omitempty"`
}
...@@ -16,6 +16,7 @@ type AdminHandlers struct { ...@@ -16,6 +16,7 @@ type AdminHandlers struct {
AntigravityOAuth *admin.AntigravityOAuthHandler AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler Setting *admin.SettingHandler
System *admin.SystemHandler System *admin.SystemHandler
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
......
...@@ -19,6 +19,7 @@ func ProvideAdminHandlers( ...@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
antigravityOAuthHandler *admin.AntigravityOAuthHandler, antigravityOAuthHandler *admin.AntigravityOAuthHandler,
proxyHandler *admin.ProxyHandler, proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler, redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
settingHandler *admin.SettingHandler, settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler, systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
...@@ -36,6 +37,7 @@ func ProvideAdminHandlers( ...@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
AntigravityOAuth: antigravityOAuthHandler, AntigravityOAuth: antigravityOAuthHandler,
Proxy: proxyHandler, Proxy: proxyHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Promo: promoHandler,
Setting: settingHandler, Setting: settingHandler,
System: systemHandler, System: systemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
...@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet( ...@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
admin.NewAntigravityOAuthHandler, admin.NewAntigravityOAuthHandler,
admin.NewProxyHandler, admin.NewProxyHandler,
admin.NewRedeemHandler, admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler, admin.NewSettingHandler,
ProvideSystemHandler, ProvideSystemHandler,
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
......
package middleware
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
prefix string
}
// NewRateLimiter 创建速率限制器实例
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
return &RateLimiter{
redis: redisClient,
prefix: "rate_limit:",
}
}
// Limit 返回速率限制中间件
// key: 限制类型标识
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
// 使用 INCR 原子操作增加计数
count, err := r.redis.Incr(ctx, redisKey).Result()
if err != nil {
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
// 首次访问时设置过期时间
if count == 1 {
r.redis.Expire(ctx, redisKey, window)
}
// 超过限制
if count > int64(limit) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
return
}
c.Next()
}
}
...@@ -9,4 +9,6 @@ const ( ...@@ -9,4 +9,6 @@ const (
ForcePlatform Key = "ctx_force_platform" ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置 // IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client" IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group Key = "ctx_group"
) )
...@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA ...@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err return err
} }
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{antigravity_quota_scopes," + string(scope) + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
_, err := r.client.Account.Update(). _, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
...@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error ...@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err return err
} }
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
builder := r.client.Account.Update(). builder := r.client.Account.Update().
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
......
...@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
......
...@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
} }
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query(). m, err := r.client.Group.Query().
Where(group.IDEQ(id)). Where(group.IDEQ(id)).
Only(ctx) Only(ctx)
...@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group ...@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
out := groupEntityToService(m) return groupEntityToService(m), nil
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
} }
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment