Commit 5906f9ab authored by yangjianbo's avatar yangjianbo
Browse files

fix(数据层): 修复数据完整性与仓储一致性问题

## 数据完整性修复 (fix-critical-data-integrity)
- 添加 error_translate.go 统一错误转换层
- 修复 nil 输入和 NotFound 错误处理
- 增强仓储层错误一致性

## 仓储一致性修复 (fix-high-repository-consistency)
- Group schema 添加 default_validity_days 字段
- Account schema 添加 proxy edge 关联
- 新增 UsageLog ent schema 定义
- 修复 UpdateBalance/UpdateConcurrency 受影响行数校验

## 数据卫生修复 (fix-medium-data-hygiene)
- UserSubscription 添加软删除支持 (SoftDeleteMixin)
- RedeemCode/Setting 添加硬删除策略文档
- account_groups/user_allowed_groups 的 created_at 声明 timestamptz
- 停止写入 legacy users.allowed_groups 列
- 新增迁移: 011-014 (索引优化、软删除、孤立数据审计、列清理)

## 测试补充
- 添加 UserSubscription 软删除测试
- 添加迁移回归测试
- 添加 NotFound 错误测试

🤖 Generated with [Claude Code](https://claude.com/claude-code

)
Co-Authored-By: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent 820bb16c
......@@ -5,6 +5,7 @@ package usersubscription
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
......@@ -18,6 +19,8 @@ const (
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldGroupID holds the string denoting the group_id field in the database.
......@@ -52,6 +55,8 @@ const (
EdgeGroup = "group"
// EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations.
EdgeAssignedByUser = "assigned_by_user"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
EdgeUsageLogs = "usage_logs"
// Table holds the table name of the usersubscription in the database.
Table = "user_subscriptions"
// UserTable is the table that holds the user relation/edge.
......@@ -75,6 +80,13 @@ const (
AssignedByUserInverseTable = "users"
// AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge.
AssignedByUserColumn = "assigned_by"
// UsageLogsTable is the table that holds the usage_logs relation/edge.
UsageLogsTable = "usage_logs"
// UsageLogsInverseTable is the table name for the UsageLog entity.
// It exists in this package in order to avoid circular dependency with the "usagelog" package.
UsageLogsInverseTable = "usage_logs"
// UsageLogsColumn is the table column denoting the usage_logs relation/edge.
UsageLogsColumn = "subscription_id"
)
// Columns holds all SQL columns for usersubscription fields.
......@@ -82,6 +94,7 @@ var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldUserID,
FieldGroupID,
FieldStartsAt,
......@@ -108,7 +121,14 @@ func ValidColumn(column string) bool {
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
......@@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
......@@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio
sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...))
}
}
// ByUsageLogsCount orders the results by usage_logs count.
func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...)
}
}
// ByUsageLogs orders the results by usage_logs terms.
func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
......@@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn),
)
}
func newUsageLogsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UsageLogsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
}
......@@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
......@@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldDeletedAt))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
......@@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription {
})
}
// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge.
func HasUsageLogs() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates).
func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newUsageLogsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.AndPredicates(predicates...))
......
......@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
......@@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr
return _c
}
// SetDeletedAt sets the "deleted_at" field.
func (_c *UserSubscriptionCreate) SetDeletedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetDeletedAt(v)
return _c
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetDeletedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field.
func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate {
_c.mutation.SetUserID(v)
......@@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr
return _c.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_c *UserSubscriptionCreate) AddUsageLogIDs(ids ...int64) *UserSubscriptionCreate {
_c.mutation.AddUsageLogIDs(ids...)
return _c
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_c *UserSubscriptionCreate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
return _c.mutation
......@@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
// Save creates the UserSubscription in the database.
func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) {
_c.defaults()
if err := _c.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
......@@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_c *UserSubscriptionCreate) defaults() {
func (_c *UserSubscriptionCreate) defaults() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
if usersubscription.DefaultCreatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
if usersubscription.DefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
......@@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() {
_c.mutation.SetMonthlyUsageUsd(v)
}
if _, ok := _c.mutation.AssignedAt(); !ok {
if usersubscription.DefaultAssignedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)")
}
v := usersubscription.DefaultAssignedAt()
_c.mutation.SetAssignedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
......@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
_node.DeletedAt = &value
}
if value, ok := _c.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
_node.StartsAt = value
......@@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre
_node.AssignedBy = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
......@@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert {
return u
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsert) SetDeletedAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldDeletedAt, v)
return u
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateDeletedAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldDeletedAt)
return u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsert) ClearDeletedAt() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldDeletedAt)
return u
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldUserID, v)
......@@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne
})
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) SetDeletedAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertOne) ClearDeletedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
......@@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu
})
}
// SetDeletedAt sets the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) SetDeletedAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDeletedAt(v)
})
}
// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDeletedAt()
})
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (u *UserSubscriptionUpsertBulk) ClearDeletedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDeletedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
......
......@@ -4,6 +4,7 @@ package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
......@@ -13,6 +14,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
......@@ -27,6 +29,7 @@ type UserSubscriptionQuery struct {
withUser *UserQuery
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
return query
}
// QueryUsageLogs chains the current query on the "usage_logs" edge.
func (_q *UserSubscriptionQuery) QueryUsageLogs() *UsageLogQuery {
query := (&UsageLogClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(usagelog.Table, usagelog.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found.
func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) {
......@@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
withAssignedByUser: _q.withAssignedByUser.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
......@@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U
return _q
}
// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to
// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserSubscriptionQuery {
query := (&UsageLogClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUsageLogs = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
......@@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
var (
nodes = []*UserSubscription{}
_spec = _q.querySpec()
loadedTypes = [3]bool{
loadedTypes = [4]bool{
_q.withUser != nil,
_q.withGroup != nil,
_q.withAssignedByUser != nil,
_q.withUsageLogs != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
......@@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
return nil, err
}
}
if query := _q.withUsageLogs; query != nil {
if err := _q.loadUsageLogs(ctx, query, nodes,
func(n *UserSubscription) { n.Edges.UsageLogs = []*UsageLog{} },
func(n *UserSubscription, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
......@@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *
}
return nil
}
func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *UsageLog)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*UserSubscription)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(usagelog.FieldSubscriptionID)
}
query.Where(predicate.UsageLog(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(usersubscription.UsageLogsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.SubscriptionID
if fk == nil {
return fmt.Errorf(`foreign-key "subscription_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "subscription_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
......
......@@ -13,6 +13,7 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
......@@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdate) SetDeletedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdate) ClearDeletedAt() *UserSubscriptionUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetUserID(v)
......@@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp
return _u.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdate) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation {
return _u.mutation
......@@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdate) ClearUsageLogs() *UserSubscriptionUpdate {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdate) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdate {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdate) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
......@@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdate) defaults() {
func (_u *UserSubscriptionUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
......@@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
......@@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
......@@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) SetDeletedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserSubscriptionUpdateOne) ClearDeletedAt() *UserSubscriptionUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetUserID(v)
......@@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio
return _u.SetAssignedByUserID(v.ID)
}
// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs.
func (_u *UserSubscriptionUpdateOne) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.AddUsageLogIDs(ids...)
return _u
}
// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddUsageLogIDs(ids...)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation {
return _u.mutation
......@@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda
return _u
}
// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity.
func (_u *UserSubscriptionUpdateOne) ClearUsageLogs() *UserSubscriptionUpdateOne {
_u.mutation.ClearUsageLogs()
return _u
}
// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne {
_u.mutation.RemoveUsageLogIDs(ids...)
return _u
}
// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities.
func (_u *UserSubscriptionUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveUsageLogIDs(ids...)
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne {
_u.mutation.Where(ps...)
......@@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use
// Save executes the query and returns the updated UserSubscription entity.
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) {
_u.defaults()
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
......@@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdateOne) defaults() {
func (_u *UserSubscriptionUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if usersubscription.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
......@@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
......@@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: usersubscription.UsageLogsTable,
Columns: []string{usersubscription.UsageLogsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserSubscription{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
......
......@@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"time"
......@@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
return service.ErrAccountNilInput
}
builder := r.client.Account.Create().
......@@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
created, err := builder.Save(ctx)
if err != nil {
return err
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.ID = created.ID
......@@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中(ErrTxStarted),复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
......@@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中(ErrTxStarted),复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err
}
if len(groupIDs) == 0 {
if tx != nil {
return tx.Commit()
}
return nil
}
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs {
builders = append(builders, r.client.AccountGroup.Create().
builders = append(builders, txClient.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(i+1),
)
}
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
......@@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil
}
accountExtra, err := r.client.Account.Query().
Where(dbaccount.IDEQ(id)).
Select(dbaccount.FieldExtra).
Only(ctx)
// 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题
payload, err := json.Marshal(updates)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
return err
}
extra := normalizeJSONMap(accountExtra.Extra)
for k, v := range updates {
extra[k] = v
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
)
if err != nil {
return err
}
_, err = r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
......
......@@ -318,12 +318,13 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
DefaultValidityDays: g.DefaultValidityDays,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
......
package repository
import (
"context"
"database/sql"
"errors"
"strings"
......@@ -10,6 +11,25 @@ import (
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
......
......@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx)
if err == nil {
......@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
......@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分未找到与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil {
return nil, err
}
......@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil {
return nil, err
}
......@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
// 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err
}
}
......@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
// 3. Remove the group id from users.allowed_groups array (legacy representation).
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
// 3. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
......
......@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
......@@ -53,6 +53,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
......
......@@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}
......
......@@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
affected, err := r.client.RedeemCode.Update().
client := clientFromContext(ctx, r.client)
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).
......
......@@ -7,10 +7,12 @@ import (
"fmt"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
......@@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}
// --- UserSubscription 软删除测试 ---
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
t.Helper()
g, err := client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err, "create ent group")
return g
}
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")
_, err := repo.GetByID(ctx, sub.ID)
require.Error(t, err, "deleted rows should be hidden by default")
_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(sub.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))
repo := NewUserSubscriptionRepository(client)
sub1 := &service.UserSubscription{
UserID: u.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")
sub2 := &service.UserSubscription{
UserID: u.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")
// 软删除 sub1
require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")
// ListByUserID 应只返回未删除的订阅
subs, err := repo.ListByUserID(ctx, u.ID)
require.NoError(t, err, "ListByUserID")
require.Len(t, subs, 1, "should only return non-deleted subscriptions")
require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}
......@@ -1109,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
......@@ -1135,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
......@@ -1177,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
......@@ -1203,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
......
......@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type userRepository struct {
......@@ -86,10 +85,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err == nil {
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
return out, nil
}
......@@ -102,10 +102,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
......@@ -240,11 +241,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err == nil {
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
if err != nil {
return nil, nil, err
}
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
......@@ -252,12 +254,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
n, err := r.client.User.Update().
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
AddBalance(-amount).
Save(ctx)
......@@ -271,8 +281,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
......@@ -280,33 +297,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
exec := r.sql
if exec == nil {
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。
exec = r.client
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx)
if err != nil {
return 0, err
}
arrayRes, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
)
if err != nil {
return 0, err
}
arrayAffected, _ := arrayRes.RowsAffected()
if int64(joinAffected) > arrayAffected {
return int64(joinAffected), nil
}
return arrayAffected, nil
return int64(affected), nil
}
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
......@@ -323,10 +321,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
......@@ -356,8 +355,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
}
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
......@@ -376,12 +374,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
......@@ -392,16 +388,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
}
}
// Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}
......
......@@ -508,3 +508,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}
......@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.Create().
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
......@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
......@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
......@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context,
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
......@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
......@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
......@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
......@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
total, err := q.Clone().Count(ctx)
if err != nil {
......@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query()
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
q = q.Where(usersubscription.UserIDEQ(*userID))
}
......@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
return client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
......@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
......@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
......@@ -266,24 +283,112 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage 原子性地累加用量并校验限额。
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
// NULL 限额表示无限制
const atomicUpdateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil // 更新成功
}
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
switch reason {
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
n, err := r.client.UserSubscription.Update().
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
......@@ -296,7 +401,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
......@@ -309,12 +415,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
......@@ -325,7 +433,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}
......
......@@ -4,6 +4,7 @@ package repository
import (
"context"
"fmt"
"testing"
"time"
......@@ -631,3 +632,249 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
// --- 限额检查与软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
s.T().Helper()
create := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeSubscription)
if daily != nil {
create.SetDailyLimitUsd(*daily)
}
if weekly != nil {
create.SetWeeklyLimitUsd(*weekly)
}
if monthly != nil {
create.SetMonthlyLimitUsd(*monthly)
}
g, err := create.Save(s.ctx)
s.Require().NoError(err, "create group with limits")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 先增加 9.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 2.0,会超过 10.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
s.Require().Error(err, "should fail when daily limit exceeded")
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
// 验证用量没有变化
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
weeklyLimit := 50.0
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 45.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 10.0,会超过 50.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().Error(err, "should fail when weekly limit exceeded")
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
monthlyLimit := 100.0
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 90.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 20.0,会超过 100.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
s.Require().Error(err, "should fail when monthly limit exceeded")
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 应该可以增加任意金额
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
s.Require().NoError(err, "should succeed without limits")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 正好达到限额应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().NoError(err, "should succeed at exact limit")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
group := s.mustCreateGroup("g-softdeleted")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 软删除分组
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
s.Require().NoError(err, "soft delete group")
// IncrementUsage 应该失败,因为分组已软删除
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
s.Require().Error(err, "should fail for soft-deleted group")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
s.Require().Error(err, "should fail for non-existent subscription")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
// --- nil 入参测试 ---
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
err := s.repo.Create(s.ctx, nil)
s.Require().Error(err, "Create should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
err := s.repo.Update(s.ctx, nil)
s.Require().Error(err, "Update should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
// --- 并发用量更新测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
const incrementPerGoroutine = 1.5
// 启动多个 goroutine 并发调用 IncrementUsage
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
}()
}
// 等待所有 goroutine 完成
for i := 0; i < numGoroutines; i++ {
err := <-errCh
s.Require().NoError(err, "IncrementUsage should succeed")
}
// 验证累加结果正确
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
dailyLimit := 5.0
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const numAttempts = 10
const incrementPerAttempt = 1.0
successCount := 0
for i := 0; i < numAttempts; i++ {
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
if err == nil {
successCount++
}
}
// 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额)
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
// 验证最终用量等于限额
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())
s.Require().NoError(err, "begin tx")
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
txCtx := dbent.NewTxContext(context.Background(), tx)
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
userEnt, err := tx.Client().User.Create().
SetEmail("tx-user-" + suffix + "@example.com").
SetPasswordHash("test").
Save(txCtx)
s.Require().NoError(err, "create user in tx")
groupEnt, err := tx.Client().Group.Create().
SetName("tx-group-" + suffix).
Save(txCtx)
s.Require().NoError(err, "create group in tx")
repo := NewUserSubscriptionRepository(baseClient)
sub := &service.UserSubscription{
UserID: userEnt.ID,
GroupID: groupEnt.ID,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Status: service.SubscriptionStatusActive,
AssignedAt: time.Now(),
Notes: "tx",
}
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
s.Require().NoError(tx.Rollback(), "rollback tx")
tx = nil
_, err = repo.GetByID(context.Background(), sub.ID)
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
......@@ -11,6 +11,7 @@ import (
var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
type AccountRepository interface {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment