Commit f6f072cb authored by Edric Li's avatar Edric Li
Browse files

Merge branch 'main' into feat/api-key-ip-restriction

parents 5265b12c ff087586
......@@ -871,6 +871,29 @@ func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.Use
})
}
// HasPromoCodeUsages applies the HasEdge predicate on the "promo_code_usages" edge.
func HasPromoCodeUsages() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPromoCodeUsagesWith applies the HasEdge predicate on the "promo_code_usages" edge with a given conditions (other predicates).
func HasPromoCodeUsagesWith(preds ...predicate.PromoCodeUsage) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPromoCodeUsagesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
......
......@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
return _c.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_c *UserCreate) AddPromoCodeUsageIDs(ids ...int64) *UserCreate {
_c.mutation.AddPromoCodeUsageIDs(ids...)
return _c
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
......@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
......
......@@ -9,12 +9,14 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -37,7 +39,9 @@ type UserQuery struct {
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
return query
}
// QueryPromoCodeUsages chains the current query on the "promo_code_usages" edge.
func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
......@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
......@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery))
return _q
}
// WithPromoCodeUsages tells the query-builder to eager-load the nodes that are connected to
// the "promo_code_usages" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *UserQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPromoCodeUsages = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
......@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [8]bool{
loadedTypes = [9]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
......@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withUserAllowedGroups != nil,
}
)
......@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withPromoCodeUsages; query != nil {
if err := _q.loadPromoCodeUsages(ctx, query, nodes,
func(n *User) { n.Edges.PromoCodeUsages = []*PromoCodeUsage{} },
func(n *User, e *PromoCodeUsage) { n.Edges.PromoCodeUsages = append(n.Edges.PromoCodeUsages, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
......@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu
}
return nil
}
func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUsageQuery, nodes []*User, init func(*User), assign func(*User, *PromoCodeUsage)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(promocodeusage.FieldUserID)
}
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PromoCodeUsagesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
......@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow
func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserQuery) ForUpdate(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
selector
......
......@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdate) AddPromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
......@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) ClearPromoCodeUsages() *UserUpdate {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdate) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
......@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
......@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdateOne) AddPromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
......@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) ClearPromoCodeUsages() *UserUpdateOne {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdateOne) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
......@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/group"
......@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct {
predicates []predicate.UserAllowedGroup
withUser *UserQuery
withGroup *GroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer
func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Unique = false
_spec.Node.Columns = nil
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
......@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAllowedGroupQuery) ForUpdate(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAllowedGroupQuery) ForShare(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities.
type UserAllowedGroupGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct {
inters []Interceptor
predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeDefinitionQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeDefinitionQuery) ForShare(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct {
selector
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
......
......@@ -687,7 +687,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type PromoHandler struct {
promoService *service.PromoService
}
// NewPromoHandler creates a new admin promo handler
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
return &PromoHandler{
promoService: promoService,
}
}
// CreatePromoCodeRequest represents create promo code request
type CreatePromoCodeRequest struct {
Code string `json:"code"` // 可选,为空则自动生成
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
Notes string `json:"notes"` // 备注
}
// UpdatePromoCodeRequest represents update promo code request
type UpdatePromoCodeRequest struct {
Code *string `json:"code"`
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt *int64 `json:"expires_at"`
Notes *string `json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func (h *PromoHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.PromoCodeFromService(&codes[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func (h *PromoHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func (h *PromoHandler) Create(c *gin.Context) {
var req CreatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.CreatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
code, err := h.promoService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Update(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
var req UpdatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.UpdatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Status: req.Status,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
if *req.ExpiresAt == 0 {
// 0 表示清除过期时间
input.ExpiresAt = nil
} else {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
}
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
err = h.promoService.Delete(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func (h *PromoHandler) GetUsages(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCodeUsage, 0, len(usages))
for i := range usages {
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
......@@ -12,19 +12,21 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
}
}
......@@ -34,6 +36,7 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
}
// SendVerifyCodeRequest 发送验证码请求
......@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
}
// ValidatePromoCodeRequest 验证优惠码请求
type ValidatePromoCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type ValidatePromoCodeResponse struct {
Valid bool `json:"valid"`
BonusAmount float64 `json:"bonus_amount,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
if err != nil {
// 根据错误类型返回对应的错误码
errorCode := "PROMO_CODE_INVALID"
switch err {
case service.ErrPromoCodeNotFound:
errorCode = "PROMO_CODE_NOT_FOUND"
case service.ErrPromoCodeExpired:
errorCode = "PROMO_CODE_EXPIRED"
case service.ErrPromoCodeDisabled:
errorCode = "PROMO_CODE_DISABLED"
case service.ErrPromoCodeMaxUsed:
errorCode = "PROMO_CODE_MAX_USED"
case service.ErrPromoCodeAlreadyUsed:
errorCode = "PROMO_CODE_ALREADY_USED"
}
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: errorCode,
})
return
}
if promoCode == nil {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_INVALID",
})
return
}
response.Success(c, ValidatePromoCodeResponse{
Valid: true,
BonusAmount: promoCode.BonusAmount,
})
}
......@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors: r.Errors,
}
}
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
if pc == nil {
return nil
}
return &PromoCode{
ID: pc.ID,
Code: pc.Code,
BonusAmount: pc.BonusAmount,
MaxUses: pc.MaxUses,
UsedCount: pc.UsedCount,
Status: pc.Status,
ExpiresAt: pc.ExpiresAt,
Notes: pc.Notes,
CreatedAt: pc.CreatedAt,
UpdatedAt: pc.UpdatedAt,
}
}
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
if u == nil {
return nil
}
return &PromoCodeUsage{
ID: u.ID,
PromoCodeID: u.PromoCodeID,
UserID: u.UserID,
BonusAmount: u.BonusAmount,
UsedAt: u.UsedAt,
User: UserFromServiceShallow(u.User),
}
}
......@@ -250,3 +250,28 @@ type BulkAssignResult struct {
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码
type PromoCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
BonusAmount float64 `json:"bonus_amount"`
MaxUses int `json:"max_uses"`
UsedCount int `json:"used_count"`
Status string `json:"status"`
ExpiresAt *time.Time `json:"expires_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type PromoCodeUsage struct {
ID int64 `json:"id"`
PromoCodeID int64 `json:"promo_code_id"`
UserID int64 `json:"user_id"`
BonusAmount float64 `json:"bonus_amount"`
UsedAt time.Time `json:"used_at"`
User *User `json:"user,omitempty"`
}
......@@ -16,6 +16,7 @@ type AdminHandlers struct {
AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
......
......@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
......@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
AntigravityOAuth: antigravityOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Promo: promoHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
......@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
admin.NewAntigravityOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
......
package middleware
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
prefix string
}
// NewRateLimiter 创建速率限制器实例
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
return &RateLimiter{
redis: redisClient,
prefix: "rate_limit:",
}
}
// Limit 返回速率限制中间件
// key: 限制类型标识
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
// 使用 INCR 原子操作增加计数
count, err := r.redis.Incr(ctx, redisKey).Result()
if err != nil {
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
// 首次访问时设置过期时间
if count == 1 {
r.redis.Expire(ctx, redisKey, window)
}
// 超过限制
if count > int64(limit) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
return
}
c.Next()
}
}
......@@ -9,4 +9,6 @@ const (
ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group Key = "ctx_group"
)
......@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{antigravity_quota_scopes," + string(scope) + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
......@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err
}
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
builder := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
......
......@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
......
......@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query().
Where(group.IDEQ(id)).
Only(ctx)
......@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
out := groupEntityToService(m)
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
return groupEntityToService(m), nil
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment