"backend/internal/vscode:/vscode.git/clone" did not exist on "1c6393b13137368c798648f455c1c0db07bec16d"
Commit d2fc14fb authored by long's avatar long
Browse files

feat: 实现注册优惠码功能

  - 支持创建/编辑/删除优惠码,设置赠送金额和使用限制
  - 注册页面实时验证优惠码并显示赠送金额
  - 支持 URL 参数自动填充 (?promo=CODE)
  - 添加优惠码验证接口速率限制
  - 使用数据库行锁防止并发超限
  - 新增后台优惠码管理页面,支持复制注册链接
parent 7d1fe818
......@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
return _c.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_c *UserCreate) AddPromoCodeUsageIDs(ids ...int64) *UserCreate {
_c.mutation.AddPromoCodeUsageIDs(ids...)
return _c
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
......@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
......
......@@ -9,12 +9,14 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -37,7 +39,9 @@ type UserQuery struct {
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
return query
}
// QueryPromoCodeUsages chains the current query on the "promo_code_usages" edge.
func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
......@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
......@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery))
return _q
}
// WithPromoCodeUsages tells the query-builder to eager-load the nodes that are connected to
// the "promo_code_usages" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *UserQuery {
query := (&PromoCodeUsageClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPromoCodeUsages = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
......@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [8]bool{
loadedTypes = [9]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
......@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withUserAllowedGroups != nil,
}
)
......@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withPromoCodeUsages; query != nil {
if err := _q.loadPromoCodeUsages(ctx, query, nodes,
func(n *User) { n.Edges.PromoCodeUsages = []*PromoCodeUsage{} },
func(n *User, e *PromoCodeUsage) { n.Edges.PromoCodeUsages = append(n.Edges.PromoCodeUsages, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
......@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu
}
return nil
}
func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUsageQuery, nodes []*User, init func(*User), assign func(*User, *PromoCodeUsage)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(promocodeusage.FieldUserID)
}
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PromoCodeUsagesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
......@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow
func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserQuery) ForUpdate(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserGroupBy is the group-by builder for User entities.
type UserGroupBy struct {
selector
......
......@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
......@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdate) AddPromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
......@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdate) ClearPromoCodeUsages() *UserUpdate {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdate) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
......@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
......@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat
return _u.AddAttributeValueIDs(ids...)
}
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
func (_u *UserUpdateOne) AddPromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPromoCodeUsageIDs(ids...)
return _u
}
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPromoCodeUsageIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
......@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp
return _u.RemoveAttributeValueIDs(ids...)
}
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
func (_u *UserUpdateOne) ClearPromoCodeUsages() *UserUpdateOne {
_u.mutation.ClearPromoCodeUsages()
return _u
}
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
func (_u *UserUpdateOne) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePromoCodeUsageIDs(ids...)
return _u
}
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePromoCodeUsageIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
......@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PromoCodeUsagesTable,
Columns: []string{user.PromoCodeUsagesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/group"
......@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct {
predicates []predicate.UserAllowedGroup
withUser *UserQuery
withGroup *GroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer
func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Unique = false
_spec.Node.Columns = nil
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
......@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAllowedGroupQuery) ForUpdate(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAllowedGroupQuery) ForShare(opts ...sql.LockOption) *UserAllowedGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities.
type UserAllowedGroupGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct {
inters []Interceptor
predicates []predicate.UserAttributeDefinition
withValues *UserAttributeValueQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeDefinitionQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeDefinitionQuery) ForShare(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
type UserAttributeDefinitionGroupBy struct {
selector
......
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type PromoHandler struct {
promoService *service.PromoService
}
// NewPromoHandler creates a new admin promo handler
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
return &PromoHandler{
promoService: promoService,
}
}
// CreatePromoCodeRequest represents create promo code request
type CreatePromoCodeRequest struct {
Code string `json:"code"` // 可选,为空则自动生成
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
Notes string `json:"notes"` // 备注
}
// UpdatePromoCodeRequest represents update promo code request
type UpdatePromoCodeRequest struct {
Code *string `json:"code"`
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt *int64 `json:"expires_at"`
Notes *string `json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func (h *PromoHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.PromoCodeFromService(&codes[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func (h *PromoHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func (h *PromoHandler) Create(c *gin.Context) {
var req CreatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.CreatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
code, err := h.promoService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Update(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
var req UpdatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.UpdatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Status: req.Status,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
if *req.ExpiresAt == 0 {
// 0 表示清除过期时间
input.ExpiresAt = nil
} else {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
}
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
err = h.promoService.Delete(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func (h *PromoHandler) GetUsages(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCodeUsage, 0, len(usages))
for i := range usages {
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
......@@ -12,19 +12,21 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
}
}
......@@ -34,6 +36,7 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
}
// SendVerifyCodeRequest 发送验证码请求
......@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
}
// ValidatePromoCodeRequest 验证优惠码请求
type ValidatePromoCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type ValidatePromoCodeResponse struct {
Valid bool `json:"valid"`
BonusAmount float64 `json:"bonus_amount,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
if err != nil {
// 根据错误类型返回对应的错误码
errorCode := "PROMO_CODE_INVALID"
switch err {
case service.ErrPromoCodeNotFound:
errorCode = "PROMO_CODE_NOT_FOUND"
case service.ErrPromoCodeExpired:
errorCode = "PROMO_CODE_EXPIRED"
case service.ErrPromoCodeDisabled:
errorCode = "PROMO_CODE_DISABLED"
case service.ErrPromoCodeMaxUsed:
errorCode = "PROMO_CODE_MAX_USED"
case service.ErrPromoCodeAlreadyUsed:
errorCode = "PROMO_CODE_ALREADY_USED"
}
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: errorCode,
})
return
}
if promoCode == nil {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_INVALID",
})
return
}
response.Success(c, ValidatePromoCodeResponse{
Valid: true,
BonusAmount: promoCode.BonusAmount,
})
}
......@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors: r.Errors,
}
}
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
if pc == nil {
return nil
}
return &PromoCode{
ID: pc.ID,
Code: pc.Code,
BonusAmount: pc.BonusAmount,
MaxUses: pc.MaxUses,
UsedCount: pc.UsedCount,
Status: pc.Status,
ExpiresAt: pc.ExpiresAt,
Notes: pc.Notes,
CreatedAt: pc.CreatedAt,
UpdatedAt: pc.UpdatedAt,
}
}
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
if u == nil {
return nil
}
return &PromoCodeUsage{
ID: u.ID,
PromoCodeID: u.PromoCodeID,
UserID: u.UserID,
BonusAmount: u.BonusAmount,
UsedAt: u.UsedAt,
User: UserFromServiceShallow(u.User),
}
}
......@@ -250,3 +250,28 @@ type BulkAssignResult struct {
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码
type PromoCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
BonusAmount float64 `json:"bonus_amount"`
MaxUses int `json:"max_uses"`
UsedCount int `json:"used_count"`
Status string `json:"status"`
ExpiresAt *time.Time `json:"expires_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type PromoCodeUsage struct {
ID int64 `json:"id"`
PromoCodeID int64 `json:"promo_code_id"`
UserID int64 `json:"user_id"`
BonusAmount float64 `json:"bonus_amount"`
UsedAt time.Time `json:"used_at"`
User *User `json:"user,omitempty"`
}
......@@ -16,6 +16,7 @@ type AdminHandlers struct {
AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
......
......@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
......@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
AntigravityOAuth: antigravityOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Promo: promoHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
......@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
admin.NewAntigravityOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
......
package middleware
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
prefix string
}
// NewRateLimiter 创建速率限制器实例
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
return &RateLimiter{
redis: redisClient,
prefix: "rate_limit:",
}
}
// Limit 返回速率限制中间件
// key: 限制类型标识
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
// 使用 INCR 原子操作增加计数
count, err := r.redis.Incr(ctx, redisKey).Result()
if err != nil {
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
// 首次访问时设置过期时间
if count == 1 {
r.redis.Expire(ctx, redisKey, window)
}
// 超过限制
if count > int64(limit) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
return
}
c.Next()
}
}
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type promoCodeRepository struct {
client *dbent.Client
}
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
return &promoCodeRepository{client: client}
}
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.Create().
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
code.ID = created.ID
code.CreatedAt = created.CreatedAt
code.UpdatedAt = created.UpdatedAt
return nil
}
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
client := clientFromContext(ctx, r.client)
m, err := client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
ForUpdate().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
updated, err := builder.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPromoCodeNotFound
}
return err
}
code.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
return err
}
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "")
}
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
q := r.client.PromoCode.Query()
if status != "" {
q = q.Where(promocode.StatusEQ(status))
}
if search != "" {
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := promoCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
SetPromoCodeID(usage.PromoCodeID).
SetUserID(usage.UserID).
SetBonusAmount(usage.BonusAmount).
SetUsedAt(usage.UsedAt).
Save(ctx)
if err != nil {
return err
}
usage.ID = created.ID
return nil
}
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
m, err := r.client.PromoCodeUsage.Query().
Where(
promocodeusage.PromoCodeIDEQ(promoCodeID),
promocodeusage.UserIDEQ(userID),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return promoCodeUsageEntityToService(m), nil
}
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
usages, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocodeusage.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsages := promoCodeUsageEntitiesToService(usages)
return outUsages, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.UpdateOneID(id).
AddUsedCount(1).
Save(ctx)
return err
}
// Entity to Service conversions
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
if m == nil {
return nil
}
return &service.PromoCode{
ID: m.ID,
Code: m.Code,
BonusAmount: m.BonusAmount,
MaxUses: m.MaxUses,
UsedCount: m.UsedCount,
Status: m.Status,
ExpiresAt: m.ExpiresAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
out := make([]service.PromoCode, 0, len(models))
for i := range models {
if s := promoCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
if m == nil {
return nil
}
out := &service.PromoCodeUsage{
ID: m.ID,
PromoCodeID: m.PromoCodeID,
UserID: m.UserID,
BonusAmount: m.BonusAmount,
UsedAt: m.UsedAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
return out
}
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
out := make([]service.PromoCodeUsage, 0, len(models))
for i := range models {
if s := promoCodeUsageEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
......@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
......
......@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProviderSet 提供服务器层的依赖
......@@ -30,6 +31,7 @@ func ProvideRouter(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
redisClient *redis.Client,
) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
......@@ -47,7 +49,7 @@ func ProvideRouter(
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
}
// ProvideHTTPServer 提供 HTTP 服务器
......
......@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SetupRouter 配置路由器中间件和路由
......@@ -21,6 +22,7 @@ func SetupRouter(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
......@@ -33,7 +35,7 @@ func SetupRouter(
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
return r
}
......@@ -48,6 +50,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
redisClient *redis.Client,
) {
// 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r)
......@@ -56,7 +59,7 @@ func registerRoutes(
v1 := r.Group("/api/v1")
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth)
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
......
......@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理
registerRedeemCodeRoutes(admin, h)
// 优惠码管理
registerPromoCodeRoutes(admin, h)
// 系统设置
registerSettingsRoutes(admin, h)
......@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
promoCodes := admin.Group("/promo-codes")
{
promoCodes.GET("", h.Admin.Promo.List)
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
promoCodes.POST("", h.Admin.Promo.Create)
promoCodes.PUT("/:id", h.Admin.Promo.Update)
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings")
{
......
package routes
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
) {
// 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment