Commit ec87f39d authored by shaw's avatar shaw
Browse files

feat: 从 gorm 迁移到 ent (#92)

## 主要变更

- 将 ORM 从 GORM 迁移到 Ent
- 使用 SQL 文件迁移替代 GORM AutoMigrate
- 新增迁移运行器支持分布式锁和校验和验证
- 优化 Repository 层查询,新增轻量级存在性检查方法
- 新增完整的单元测试覆盖删除操作

## 迁移优势

- 类型安全与编译期校验
- 关系建模更清晰(Edge/Through)
- 查询一致性更好
- 迁移可控(SQL 文件作为唯一事实来源)
- 可维护性提升

## 新增迁移文件

- 005_schema_parity.sql: 字段对齐
- 006_fix_invalid_subscription_expires_at.sql: 修复过期时间
- 007_add_user_allowed_groups.sql: 用户允许分组表
- 008_seed_default_group.sql: 默认分组种子
- 009_fix_usage_logs_cache_columns.sql: 缓存列修复
parents fb883f00 3d296d88
// Code generated by ent, DO NOT EDIT.
package usersubscription
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the usersubscription type in the database.
Label = "user_subscription"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldStartsAt holds the string denoting the starts_at field in the database.
FieldStartsAt = "starts_at"
// FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldDailyWindowStart holds the string denoting the daily_window_start field in the database.
FieldDailyWindowStart = "daily_window_start"
// FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database.
FieldWeeklyWindowStart = "weekly_window_start"
// FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database.
FieldMonthlyWindowStart = "monthly_window_start"
// FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database.
FieldDailyUsageUsd = "daily_usage_usd"
// FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database.
FieldWeeklyUsageUsd = "weekly_usage_usd"
// FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database.
FieldMonthlyUsageUsd = "monthly_usage_usd"
// FieldAssignedBy holds the string denoting the assigned_by field in the database.
FieldAssignedBy = "assigned_by"
// FieldAssignedAt holds the string denoting the assigned_at field in the database.
FieldAssignedAt = "assigned_at"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations.
EdgeGroup = "group"
// EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations.
EdgeAssignedByUser = "assigned_by_user"
// Table holds the table name of the usersubscription in the database.
Table = "user_subscriptions"
// UserTable is the table that holds the user relation/edge.
UserTable = "user_subscriptions"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
// GroupTable is the table that holds the group relation/edge.
GroupTable = "user_subscriptions"
// GroupInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package.
GroupInverseTable = "groups"
// GroupColumn is the table column denoting the group relation/edge.
GroupColumn = "group_id"
// AssignedByUserTable is the table that holds the assigned_by_user relation/edge.
AssignedByUserTable = "user_subscriptions"
// AssignedByUserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
AssignedByUserInverseTable = "users"
// AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge.
AssignedByUserColumn = "assigned_by"
)
// Columns holds all SQL columns for usersubscription fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldUserID,
FieldGroupID,
FieldStartsAt,
FieldExpiresAt,
FieldStatus,
FieldDailyWindowStart,
FieldWeeklyWindowStart,
FieldMonthlyWindowStart,
FieldDailyUsageUsd,
FieldWeeklyUsageUsd,
FieldMonthlyUsageUsd,
FieldAssignedBy,
FieldAssignedAt,
FieldNotes,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field.
DefaultDailyUsageUsd float64
// DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field.
DefaultWeeklyUsageUsd float64
// DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field.
DefaultMonthlyUsageUsd float64
// DefaultAssignedAt holds the default value on creation for the "assigned_at" field.
DefaultAssignedAt func() time.Time
)
// OrderOption defines the ordering options for the UserSubscription queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
}
// ByStartsAt orders the results by the starts_at field.
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
}
// ByExpiresAt orders the results by the expires_at field.
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByDailyWindowStart orders the results by the daily_window_start field.
func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc()
}
// ByWeeklyWindowStart orders the results by the weekly_window_start field.
func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc()
}
// ByMonthlyWindowStart orders the results by the monthly_window_start field.
func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc()
}
// ByDailyUsageUsd orders the results by the daily_usage_usd field.
func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc()
}
// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field.
func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc()
}
// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field.
func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc()
}
// ByAssignedBy orders the results by the assigned_by field.
func ByAssignedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAssignedBy, opts...).ToFunc()
}
// ByAssignedAt orders the results by the assigned_at field.
func ByAssignedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAssignedAt, opts...).ToFunc()
}
// ByNotes orders the results by the notes field.
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
// ByGroupField orders the results by group field.
func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...))
}
}
// ByAssignedByUserField orders the results by assigned_by_user field.
func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}
func newGroupStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(GroupInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn),
)
}
func newAssignedByUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AssignedByUserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn),
)
}
// Code generated by ent, DO NOT EDIT.
package usersubscription
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldGroupID, v))
}
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
func StartsAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldStartsAt, v))
}
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
func ExpiresAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldExpiresAt, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldStatus, v))
}
// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ.
func DailyWindowStart(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDailyWindowStart, v))
}
// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ.
func WeeklyWindowStart(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ.
func MonthlyWindowStart(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ.
func DailyUsageUsd(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ.
func WeeklyUsageUsd(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ.
func MonthlyUsageUsd(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// AssignedBy applies equality check predicate on the "assigned_by" field. It's identical to AssignedByEQ.
func AssignedBy(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldAssignedBy, v))
}
// AssignedAt applies equality check predicate on the "assigned_at" field. It's identical to AssignedAtEQ.
func AssignedAt(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldAssignedAt, v))
}
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
func Notes(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldNotes, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldUserID, vs...))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldGroupID, v))
}
// GroupIDNEQ applies the NEQ predicate on the "group_id" field.
func GroupIDNEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldGroupID, v))
}
// GroupIDIn applies the In predicate on the "group_id" field.
func GroupIDIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldGroupID, vs...))
}
// GroupIDNotIn applies the NotIn predicate on the "group_id" field.
func GroupIDNotIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldGroupID, vs...))
}
// StartsAtEQ applies the EQ predicate on the "starts_at" field.
func StartsAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldStartsAt, v))
}
// StartsAtNEQ applies the NEQ predicate on the "starts_at" field.
func StartsAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldStartsAt, v))
}
// StartsAtIn applies the In predicate on the "starts_at" field.
func StartsAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldStartsAt, vs...))
}
// StartsAtNotIn applies the NotIn predicate on the "starts_at" field.
func StartsAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldStartsAt, vs...))
}
// StartsAtGT applies the GT predicate on the "starts_at" field.
func StartsAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldStartsAt, v))
}
// StartsAtGTE applies the GTE predicate on the "starts_at" field.
func StartsAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldStartsAt, v))
}
// StartsAtLT applies the LT predicate on the "starts_at" field.
func StartsAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldStartsAt, v))
}
// StartsAtLTE applies the LTE predicate on the "starts_at" field.
func StartsAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldStartsAt, v))
}
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
func ExpiresAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldExpiresAt, v))
}
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
func ExpiresAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldExpiresAt, v))
}
// ExpiresAtIn applies the In predicate on the "expires_at" field.
func ExpiresAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldExpiresAt, vs...))
}
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
func ExpiresAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldExpiresAt, vs...))
}
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
func ExpiresAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldExpiresAt, v))
}
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
func ExpiresAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldExpiresAt, v))
}
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
func ExpiresAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldExpiresAt, v))
}
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
func ExpiresAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldExpiresAt, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldStatus, v))
}
// StatusNEQ applies the NEQ predicate on the "status" field.
func StatusNEQ(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldStatus, v))
}
// StatusIn applies the In predicate on the "status" field.
func StatusIn(vs ...string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldStatus, vs...))
}
// StatusNotIn applies the NotIn predicate on the "status" field.
func StatusNotIn(vs ...string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldStatus, vs...))
}
// StatusGT applies the GT predicate on the "status" field.
func StatusGT(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldStatus, v))
}
// StatusGTE applies the GTE predicate on the "status" field.
func StatusGTE(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldStatus, v))
}
// StatusLT applies the LT predicate on the "status" field.
func StatusLT(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldStatus, v))
}
// StatusLTE applies the LTE predicate on the "status" field.
func StatusLTE(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldStatus, v))
}
// StatusContains applies the Contains predicate on the "status" field.
func StatusContains(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldContains(FieldStatus, v))
}
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
func StatusHasPrefix(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldHasPrefix(FieldStatus, v))
}
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
func StatusHasSuffix(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldHasSuffix(FieldStatus, v))
}
// StatusEqualFold applies the EqualFold predicate on the "status" field.
func StatusEqualFold(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEqualFold(FieldStatus, v))
}
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
func StatusContainsFold(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldContainsFold(FieldStatus, v))
}
// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field.
func DailyWindowStartEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field.
func DailyWindowStartNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartIn applies the In predicate on the "daily_window_start" field.
func DailyWindowStartIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field.
func DailyWindowStartNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field.
func DailyWindowStartGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldDailyWindowStart, v))
}
// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field.
func DailyWindowStartGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldDailyWindowStart, v))
}
// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field.
func DailyWindowStartLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldDailyWindowStart, v))
}
// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field.
func DailyWindowStartLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldDailyWindowStart, v))
}
// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field.
func DailyWindowStartIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldDailyWindowStart))
}
// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field.
func DailyWindowStartNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldDailyWindowStart))
}
// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field.
func WeeklyWindowStartIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field.
func WeeklyWindowStartGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field.
func WeeklyWindowStartLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldWeeklyWindowStart))
}
// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldWeeklyWindowStart))
}
// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field.
func MonthlyWindowStartIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field.
func MonthlyWindowStartGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field.
func MonthlyWindowStartLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldMonthlyWindowStart))
}
// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldMonthlyWindowStart))
}
// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdNEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field.
func DailyUsageUsdIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field.
func DailyUsageUsdNotIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field.
func DailyUsageUsdGT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdGTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field.
func DailyUsageUsdLT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdLTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNEQ(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLT(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLTE(v float64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldMonthlyUsageUsd, v))
}
// AssignedByEQ applies the EQ predicate on the "assigned_by" field.
func AssignedByEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldAssignedBy, v))
}
// AssignedByNEQ applies the NEQ predicate on the "assigned_by" field.
func AssignedByNEQ(v int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldAssignedBy, v))
}
// AssignedByIn applies the In predicate on the "assigned_by" field.
func AssignedByIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldAssignedBy, vs...))
}
// AssignedByNotIn applies the NotIn predicate on the "assigned_by" field.
func AssignedByNotIn(vs ...int64) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldAssignedBy, vs...))
}
// AssignedByIsNil applies the IsNil predicate on the "assigned_by" field.
func AssignedByIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldAssignedBy))
}
// AssignedByNotNil applies the NotNil predicate on the "assigned_by" field.
func AssignedByNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldAssignedBy))
}
// AssignedAtEQ applies the EQ predicate on the "assigned_at" field.
func AssignedAtEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldAssignedAt, v))
}
// AssignedAtNEQ applies the NEQ predicate on the "assigned_at" field.
func AssignedAtNEQ(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldAssignedAt, v))
}
// AssignedAtIn applies the In predicate on the "assigned_at" field.
func AssignedAtIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldAssignedAt, vs...))
}
// AssignedAtNotIn applies the NotIn predicate on the "assigned_at" field.
func AssignedAtNotIn(vs ...time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldAssignedAt, vs...))
}
// AssignedAtGT applies the GT predicate on the "assigned_at" field.
func AssignedAtGT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldAssignedAt, v))
}
// AssignedAtGTE applies the GTE predicate on the "assigned_at" field.
func AssignedAtGTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldAssignedAt, v))
}
// AssignedAtLT applies the LT predicate on the "assigned_at" field.
func AssignedAtLT(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldAssignedAt, v))
}
// AssignedAtLTE applies the LTE predicate on the "assigned_at" field.
func AssignedAtLTE(v time.Time) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldAssignedAt, v))
}
// NotesEQ applies the EQ predicate on the "notes" field.
func NotesEQ(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEQ(FieldNotes, v))
}
// NotesNEQ applies the NEQ predicate on the "notes" field.
func NotesNEQ(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNEQ(FieldNotes, v))
}
// NotesIn applies the In predicate on the "notes" field.
func NotesIn(vs ...string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIn(FieldNotes, vs...))
}
// NotesNotIn applies the NotIn predicate on the "notes" field.
func NotesNotIn(vs ...string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotIn(FieldNotes, vs...))
}
// NotesGT applies the GT predicate on the "notes" field.
func NotesGT(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGT(FieldNotes, v))
}
// NotesGTE applies the GTE predicate on the "notes" field.
func NotesGTE(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldGTE(FieldNotes, v))
}
// NotesLT applies the LT predicate on the "notes" field.
func NotesLT(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLT(FieldNotes, v))
}
// NotesLTE applies the LTE predicate on the "notes" field.
func NotesLTE(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldLTE(FieldNotes, v))
}
// NotesContains applies the Contains predicate on the "notes" field.
func NotesContains(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldContains(FieldNotes, v))
}
// NotesHasPrefix applies the HasPrefix predicate on the "notes" field.
func NotesHasPrefix(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldHasPrefix(FieldNotes, v))
}
// NotesHasSuffix applies the HasSuffix predicate on the "notes" field.
func NotesHasSuffix(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldHasSuffix(FieldNotes, v))
}
// NotesIsNil applies the IsNil predicate on the "notes" field.
func NotesIsNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldIsNull(FieldNotes))
}
// NotesNotNil applies the NotNil predicate on the "notes" field.
func NotesNotNil() predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldNotNull(FieldNotes))
}
// NotesEqualFold applies the EqualFold predicate on the "notes" field.
func NotesEqualFold(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldEqualFold(FieldNotes, v))
}
// NotesContainsFold applies the ContainsFold predicate on the "notes" field.
func NotesContainsFold(v string) predicate.UserSubscription {
return predicate.UserSubscription(sql.FieldContainsFold(FieldNotes, v))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasGroup applies the HasEdge predicate on the "group" edge.
func HasGroup() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasGroupWith applies the HasEdge predicate on the "group" edge with a given conditions (other predicates).
func HasGroupWith(preds ...predicate.Group) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newGroupStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAssignedByUser applies the HasEdge predicate on the "assigned_by_user" edge.
func HasAssignedByUser() predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAssignedByUserWith applies the HasEdge predicate on the "assigned_by_user" edge with a given conditions (other predicates).
func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription {
return predicate.UserSubscription(func(s *sql.Selector) {
step := newAssignedByUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserSubscription) predicate.UserSubscription {
return predicate.UserSubscription(sql.NotPredicates(p))
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionCreate is the builder for creating a UserSubscription entity.
type UserSubscriptionCreate struct {
config
mutation *UserSubscriptionMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *UserSubscriptionCreate) SetCreatedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableCreatedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *UserSubscriptionCreate) SetUpdatedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetUserID sets the "user_id" field.
func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate {
_c.mutation.SetUserID(v)
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *UserSubscriptionCreate) SetGroupID(v int64) *UserSubscriptionCreate {
_c.mutation.SetGroupID(v)
return _c
}
// SetStartsAt sets the "starts_at" field.
func (_c *UserSubscriptionCreate) SetStartsAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetStartsAt(v)
return _c
}
// SetExpiresAt sets the "expires_at" field.
func (_c *UserSubscriptionCreate) SetExpiresAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetExpiresAt(v)
return _c
}
// SetStatus sets the "status" field.
func (_c *UserSubscriptionCreate) SetStatus(v string) *UserSubscriptionCreate {
_c.mutation.SetStatus(v)
return _c
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableStatus(v *string) *UserSubscriptionCreate {
if v != nil {
_c.SetStatus(*v)
}
return _c
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_c *UserSubscriptionCreate) SetDailyWindowStart(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetDailyWindowStart(v)
return _c
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableDailyWindowStart(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetDailyWindowStart(*v)
}
return _c
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_c *UserSubscriptionCreate) SetWeeklyWindowStart(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetWeeklyWindowStart(v)
return _c
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableWeeklyWindowStart(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetWeeklyWindowStart(*v)
}
return _c
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_c *UserSubscriptionCreate) SetMonthlyWindowStart(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetMonthlyWindowStart(v)
return _c
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableMonthlyWindowStart(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetMonthlyWindowStart(*v)
}
return _c
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_c *UserSubscriptionCreate) SetDailyUsageUsd(v float64) *UserSubscriptionCreate {
_c.mutation.SetDailyUsageUsd(v)
return _c
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableDailyUsageUsd(v *float64) *UserSubscriptionCreate {
if v != nil {
_c.SetDailyUsageUsd(*v)
}
return _c
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_c *UserSubscriptionCreate) SetWeeklyUsageUsd(v float64) *UserSubscriptionCreate {
_c.mutation.SetWeeklyUsageUsd(v)
return _c
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableWeeklyUsageUsd(v *float64) *UserSubscriptionCreate {
if v != nil {
_c.SetWeeklyUsageUsd(*v)
}
return _c
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_c *UserSubscriptionCreate) SetMonthlyUsageUsd(v float64) *UserSubscriptionCreate {
_c.mutation.SetMonthlyUsageUsd(v)
return _c
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableMonthlyUsageUsd(v *float64) *UserSubscriptionCreate {
if v != nil {
_c.SetMonthlyUsageUsd(*v)
}
return _c
}
// SetAssignedBy sets the "assigned_by" field.
func (_c *UserSubscriptionCreate) SetAssignedBy(v int64) *UserSubscriptionCreate {
_c.mutation.SetAssignedBy(v)
return _c
}
// SetNillableAssignedBy sets the "assigned_by" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableAssignedBy(v *int64) *UserSubscriptionCreate {
if v != nil {
_c.SetAssignedBy(*v)
}
return _c
}
// SetAssignedAt sets the "assigned_at" field.
func (_c *UserSubscriptionCreate) SetAssignedAt(v time.Time) *UserSubscriptionCreate {
_c.mutation.SetAssignedAt(v)
return _c
}
// SetNillableAssignedAt sets the "assigned_at" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableAssignedAt(v *time.Time) *UserSubscriptionCreate {
if v != nil {
_c.SetAssignedAt(*v)
}
return _c
}
// SetNotes sets the "notes" field.
func (_c *UserSubscriptionCreate) SetNotes(v string) *UserSubscriptionCreate {
_c.mutation.SetNotes(v)
return _c
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableNotes(v *string) *UserSubscriptionCreate {
if v != nil {
_c.SetNotes(*v)
}
return _c
}
// SetUser sets the "user" edge to the User entity.
func (_c *UserSubscriptionCreate) SetUser(v *User) *UserSubscriptionCreate {
return _c.SetUserID(v.ID)
}
// SetGroup sets the "group" edge to the Group entity.
func (_c *UserSubscriptionCreate) SetGroup(v *Group) *UserSubscriptionCreate {
return _c.SetGroupID(v.ID)
}
// SetAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID.
func (_c *UserSubscriptionCreate) SetAssignedByUserID(id int64) *UserSubscriptionCreate {
_c.mutation.SetAssignedByUserID(id)
return _c
}
// SetNillableAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID if the given value is not nil.
func (_c *UserSubscriptionCreate) SetNillableAssignedByUserID(id *int64) *UserSubscriptionCreate {
if id != nil {
_c = _c.SetAssignedByUserID(*id)
}
return _c
}
// SetAssignedByUser sets the "assigned_by_user" edge to the User entity.
func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCreate {
return _c.SetAssignedByUserID(v.ID)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation {
return _c.mutation
}
// Save creates the UserSubscription in the database.
func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *UserSubscriptionCreate) SaveX(ctx context.Context) *UserSubscription {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserSubscriptionCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *UserSubscriptionCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := usersubscription.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := usersubscription.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.Status(); !ok {
v := usersubscription.DefaultStatus
_c.mutation.SetStatus(v)
}
if _, ok := _c.mutation.DailyUsageUsd(); !ok {
v := usersubscription.DefaultDailyUsageUsd
_c.mutation.SetDailyUsageUsd(v)
}
if _, ok := _c.mutation.WeeklyUsageUsd(); !ok {
v := usersubscription.DefaultWeeklyUsageUsd
_c.mutation.SetWeeklyUsageUsd(v)
}
if _, ok := _c.mutation.MonthlyUsageUsd(); !ok {
v := usersubscription.DefaultMonthlyUsageUsd
_c.mutation.SetMonthlyUsageUsd(v)
}
if _, ok := _c.mutation.AssignedAt(); !ok {
v := usersubscription.DefaultAssignedAt()
_c.mutation.SetAssignedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *UserSubscriptionCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserSubscription.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserSubscription.updated_at"`)}
}
if _, ok := _c.mutation.UserID(); !ok {
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserSubscription.user_id"`)}
}
if _, ok := _c.mutation.GroupID(); !ok {
return &ValidationError{Name: "group_id", err: errors.New(`ent: missing required field "UserSubscription.group_id"`)}
}
if _, ok := _c.mutation.StartsAt(); !ok {
return &ValidationError{Name: "starts_at", err: errors.New(`ent: missing required field "UserSubscription.starts_at"`)}
}
if _, ok := _c.mutation.ExpiresAt(); !ok {
return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "UserSubscription.expires_at"`)}
}
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "UserSubscription.status"`)}
}
if v, ok := _c.mutation.Status(); ok {
if err := usersubscription.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UserSubscription.status": %w`, err)}
}
}
if _, ok := _c.mutation.DailyUsageUsd(); !ok {
return &ValidationError{Name: "daily_usage_usd", err: errors.New(`ent: missing required field "UserSubscription.daily_usage_usd"`)}
}
if _, ok := _c.mutation.WeeklyUsageUsd(); !ok {
return &ValidationError{Name: "weekly_usage_usd", err: errors.New(`ent: missing required field "UserSubscription.weekly_usage_usd"`)}
}
if _, ok := _c.mutation.MonthlyUsageUsd(); !ok {
return &ValidationError{Name: "monthly_usage_usd", err: errors.New(`ent: missing required field "UserSubscription.monthly_usage_usd"`)}
}
if _, ok := _c.mutation.AssignedAt(); !ok {
return &ValidationError{Name: "assigned_at", err: errors.New(`ent: missing required field "UserSubscription.assigned_at"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserSubscription.user"`)}
}
if len(_c.mutation.GroupIDs()) == 0 {
return &ValidationError{Name: "group", err: errors.New(`ent: missing required edge "UserSubscription.group"`)}
}
return nil
}
func (_c *UserSubscriptionCreate) sqlSave(ctx context.Context) (*UserSubscription, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.CreateSpec) {
var (
_node = &UserSubscription{config: _c.config}
_spec = sqlgraph.NewCreateSpec(usersubscription.Table, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usersubscription.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
_node.StartsAt = value
}
if value, ok := _c.mutation.ExpiresAt(); ok {
_spec.SetField(usersubscription.FieldExpiresAt, field.TypeTime, value)
_node.ExpiresAt = value
}
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(usersubscription.FieldStatus, field.TypeString, value)
_node.Status = value
}
if value, ok := _c.mutation.DailyWindowStart(); ok {
_spec.SetField(usersubscription.FieldDailyWindowStart, field.TypeTime, value)
_node.DailyWindowStart = &value
}
if value, ok := _c.mutation.WeeklyWindowStart(); ok {
_spec.SetField(usersubscription.FieldWeeklyWindowStart, field.TypeTime, value)
_node.WeeklyWindowStart = &value
}
if value, ok := _c.mutation.MonthlyWindowStart(); ok {
_spec.SetField(usersubscription.FieldMonthlyWindowStart, field.TypeTime, value)
_node.MonthlyWindowStart = &value
}
if value, ok := _c.mutation.DailyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
_node.DailyUsageUsd = value
}
if value, ok := _c.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
_node.WeeklyUsageUsd = value
}
if value, ok := _c.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
_node.MonthlyUsageUsd = value
}
if value, ok := _c.mutation.AssignedAt(); ok {
_spec.SetField(usersubscription.FieldAssignedAt, field.TypeTime, value)
_node.AssignedAt = value
}
if value, ok := _c.mutation.Notes(); ok {
_spec.SetField(usersubscription.FieldNotes, field.TypeString, value)
_node.Notes = &value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.UserID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.GroupIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.GroupID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AssignedByUserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.AssignedBy = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserSubscription.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserSubscriptionUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserSubscriptionCreate) OnConflict(opts ...sql.ConflictOption) *UserSubscriptionUpsertOne {
_c.conflict = opts
return &UserSubscriptionUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserSubscriptionCreate) OnConflictColumns(columns ...string) *UserSubscriptionUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserSubscriptionUpsertOne{
create: _c,
}
}
type (
// UserSubscriptionUpsertOne is the builder for "upsert"-ing
// one UserSubscription node.
UserSubscriptionUpsertOne struct {
create *UserSubscriptionCreate
}
// UserSubscriptionUpsert is the "OnConflict" setter.
UserSubscriptionUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *UserSubscriptionUpsert) SetUpdatedAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldUpdatedAt)
return u
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldUserID, v)
return u
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateUserID() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldUserID)
return u
}
// SetGroupID sets the "group_id" field.
func (u *UserSubscriptionUpsert) SetGroupID(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldGroupID, v)
return u
}
// UpdateGroupID sets the "group_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateGroupID() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldGroupID)
return u
}
// SetStartsAt sets the "starts_at" field.
func (u *UserSubscriptionUpsert) SetStartsAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldStartsAt, v)
return u
}
// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateStartsAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldStartsAt)
return u
}
// SetExpiresAt sets the "expires_at" field.
func (u *UserSubscriptionUpsert) SetExpiresAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldExpiresAt, v)
return u
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateExpiresAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldExpiresAt)
return u
}
// SetStatus sets the "status" field.
func (u *UserSubscriptionUpsert) SetStatus(v string) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldStatus, v)
return u
}
// UpdateStatus sets the "status" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateStatus() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldStatus)
return u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (u *UserSubscriptionUpsert) SetDailyWindowStart(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldDailyWindowStart, v)
return u
}
// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateDailyWindowStart() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldDailyWindowStart)
return u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (u *UserSubscriptionUpsert) ClearDailyWindowStart() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldDailyWindowStart)
return u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (u *UserSubscriptionUpsert) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldWeeklyWindowStart, v)
return u
}
// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateWeeklyWindowStart() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldWeeklyWindowStart)
return u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (u *UserSubscriptionUpsert) ClearWeeklyWindowStart() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldWeeklyWindowStart)
return u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (u *UserSubscriptionUpsert) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldMonthlyWindowStart, v)
return u
}
// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateMonthlyWindowStart() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldMonthlyWindowStart)
return u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (u *UserSubscriptionUpsert) ClearMonthlyWindowStart() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldMonthlyWindowStart)
return u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (u *UserSubscriptionUpsert) SetDailyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldDailyUsageUsd, v)
return u
}
// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateDailyUsageUsd() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldDailyUsageUsd)
return u
}
// AddDailyUsageUsd adds v to the "daily_usage_usd" field.
func (u *UserSubscriptionUpsert) AddDailyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Add(usersubscription.FieldDailyUsageUsd, v)
return u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsert) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldWeeklyUsageUsd, v)
return u
}
// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateWeeklyUsageUsd() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldWeeklyUsageUsd)
return u
}
// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsert) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Add(usersubscription.FieldWeeklyUsageUsd, v)
return u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsert) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldMonthlyUsageUsd, v)
return u
}
// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateMonthlyUsageUsd() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldMonthlyUsageUsd)
return u
}
// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsert) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpsert {
u.Add(usersubscription.FieldMonthlyUsageUsd, v)
return u
}
// SetAssignedBy sets the "assigned_by" field.
func (u *UserSubscriptionUpsert) SetAssignedBy(v int64) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldAssignedBy, v)
return u
}
// UpdateAssignedBy sets the "assigned_by" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateAssignedBy() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldAssignedBy)
return u
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (u *UserSubscriptionUpsert) ClearAssignedBy() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldAssignedBy)
return u
}
// SetAssignedAt sets the "assigned_at" field.
func (u *UserSubscriptionUpsert) SetAssignedAt(v time.Time) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldAssignedAt, v)
return u
}
// UpdateAssignedAt sets the "assigned_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateAssignedAt() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldAssignedAt)
return u
}
// SetNotes sets the "notes" field.
func (u *UserSubscriptionUpsert) SetNotes(v string) *UserSubscriptionUpsert {
u.Set(usersubscription.FieldNotes, v)
return u
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *UserSubscriptionUpsert) UpdateNotes() *UserSubscriptionUpsert {
u.SetExcluded(usersubscription.FieldNotes)
return u
}
// ClearNotes clears the value of the "notes" field.
func (u *UserSubscriptionUpsert) ClearNotes() *UserSubscriptionUpsert {
u.SetNull(usersubscription.FieldNotes)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserSubscriptionUpsertOne) UpdateNewValues() *UserSubscriptionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(usersubscription.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserSubscriptionUpsertOne) Ignore() *UserSubscriptionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserSubscriptionUpsertOne) DoNothing() *UserSubscriptionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserSubscriptionCreate.OnConflict
// documentation for more info.
func (u *UserSubscriptionUpsertOne) Update(set func(*UserSubscriptionUpsert)) *UserSubscriptionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserSubscriptionUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserSubscriptionUpsertOne) SetUpdatedAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateUserID() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateUserID()
})
}
// SetGroupID sets the "group_id" field.
func (u *UserSubscriptionUpsertOne) SetGroupID(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetGroupID(v)
})
}
// UpdateGroupID sets the "group_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateGroupID() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateGroupID()
})
}
// SetStartsAt sets the "starts_at" field.
func (u *UserSubscriptionUpsertOne) SetStartsAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetStartsAt(v)
})
}
// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateStartsAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateStartsAt()
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *UserSubscriptionUpsertOne) SetExpiresAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateExpiresAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateExpiresAt()
})
}
// SetStatus sets the "status" field.
func (u *UserSubscriptionUpsertOne) SetStatus(v string) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetStatus(v)
})
}
// UpdateStatus sets the "status" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateStatus() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateStatus()
})
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (u *UserSubscriptionUpsertOne) SetDailyWindowStart(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDailyWindowStart(v)
})
}
// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateDailyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDailyWindowStart()
})
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (u *UserSubscriptionUpsertOne) ClearDailyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDailyWindowStart()
})
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (u *UserSubscriptionUpsertOne) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetWeeklyWindowStart(v)
})
}
// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateWeeklyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateWeeklyWindowStart()
})
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (u *UserSubscriptionUpsertOne) ClearWeeklyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearWeeklyWindowStart()
})
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (u *UserSubscriptionUpsertOne) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetMonthlyWindowStart(v)
})
}
// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateMonthlyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateMonthlyWindowStart()
})
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (u *UserSubscriptionUpsertOne) ClearMonthlyWindowStart() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearMonthlyWindowStart()
})
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (u *UserSubscriptionUpsertOne) SetDailyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDailyUsageUsd(v)
})
}
// AddDailyUsageUsd adds v to the "daily_usage_usd" field.
func (u *UserSubscriptionUpsertOne) AddDailyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddDailyUsageUsd(v)
})
}
// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateDailyUsageUsd() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDailyUsageUsd()
})
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsertOne) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetWeeklyUsageUsd(v)
})
}
// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsertOne) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddWeeklyUsageUsd(v)
})
}
// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateWeeklyUsageUsd() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateWeeklyUsageUsd()
})
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsertOne) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetMonthlyUsageUsd(v)
})
}
// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsertOne) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddMonthlyUsageUsd(v)
})
}
// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateMonthlyUsageUsd() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateMonthlyUsageUsd()
})
}
// SetAssignedBy sets the "assigned_by" field.
func (u *UserSubscriptionUpsertOne) SetAssignedBy(v int64) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetAssignedBy(v)
})
}
// UpdateAssignedBy sets the "assigned_by" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateAssignedBy() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateAssignedBy()
})
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (u *UserSubscriptionUpsertOne) ClearAssignedBy() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearAssignedBy()
})
}
// SetAssignedAt sets the "assigned_at" field.
func (u *UserSubscriptionUpsertOne) SetAssignedAt(v time.Time) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetAssignedAt(v)
})
}
// UpdateAssignedAt sets the "assigned_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateAssignedAt() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateAssignedAt()
})
}
// SetNotes sets the "notes" field.
func (u *UserSubscriptionUpsertOne) SetNotes(v string) *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetNotes(v)
})
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *UserSubscriptionUpsertOne) UpdateNotes() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateNotes()
})
}
// ClearNotes clears the value of the "notes" field.
func (u *UserSubscriptionUpsertOne) ClearNotes() *UserSubscriptionUpsertOne {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearNotes()
})
}
// Exec executes the query.
func (u *UserSubscriptionUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserSubscriptionCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserSubscriptionUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *UserSubscriptionUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *UserSubscriptionUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// UserSubscriptionCreateBulk is the builder for creating many UserSubscription entities in bulk.
type UserSubscriptionCreateBulk struct {
config
err error
builders []*UserSubscriptionCreate
conflict []sql.ConflictOption
}
// Save creates the UserSubscription entities in the database.
func (_c *UserSubscriptionCreateBulk) Save(ctx context.Context) ([]*UserSubscription, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*UserSubscription, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*UserSubscriptionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *UserSubscriptionCreateBulk) SaveX(ctx context.Context) []*UserSubscription {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *UserSubscriptionCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *UserSubscriptionCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.UserSubscription.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.UserSubscriptionUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *UserSubscriptionCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserSubscriptionUpsertBulk {
_c.conflict = opts
return &UserSubscriptionUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *UserSubscriptionCreateBulk) OnConflictColumns(columns ...string) *UserSubscriptionUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &UserSubscriptionUpsertBulk{
create: _c,
}
}
// UserSubscriptionUpsertBulk is the builder for "upsert"-ing
// a bulk of UserSubscription nodes.
type UserSubscriptionUpsertBulk struct {
create *UserSubscriptionCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *UserSubscriptionUpsertBulk) UpdateNewValues() *UserSubscriptionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(usersubscription.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.UserSubscription.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *UserSubscriptionUpsertBulk) Ignore() *UserSubscriptionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *UserSubscriptionUpsertBulk) DoNothing() *UserSubscriptionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the UserSubscriptionCreateBulk.OnConflict
// documentation for more info.
func (u *UserSubscriptionUpsertBulk) Update(set func(*UserSubscriptionUpsert)) *UserSubscriptionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&UserSubscriptionUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *UserSubscriptionUpsertBulk) SetUpdatedAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateUpdatedAt()
})
}
// SetUserID sets the "user_id" field.
func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateUserID() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateUserID()
})
}
// SetGroupID sets the "group_id" field.
func (u *UserSubscriptionUpsertBulk) SetGroupID(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetGroupID(v)
})
}
// UpdateGroupID sets the "group_id" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateGroupID() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateGroupID()
})
}
// SetStartsAt sets the "starts_at" field.
func (u *UserSubscriptionUpsertBulk) SetStartsAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetStartsAt(v)
})
}
// UpdateStartsAt sets the "starts_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateStartsAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateStartsAt()
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *UserSubscriptionUpsertBulk) SetExpiresAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateExpiresAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateExpiresAt()
})
}
// SetStatus sets the "status" field.
func (u *UserSubscriptionUpsertBulk) SetStatus(v string) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetStatus(v)
})
}
// UpdateStatus sets the "status" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateStatus() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateStatus()
})
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (u *UserSubscriptionUpsertBulk) SetDailyWindowStart(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDailyWindowStart(v)
})
}
// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateDailyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDailyWindowStart()
})
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (u *UserSubscriptionUpsertBulk) ClearDailyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearDailyWindowStart()
})
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (u *UserSubscriptionUpsertBulk) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetWeeklyWindowStart(v)
})
}
// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateWeeklyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateWeeklyWindowStart()
})
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (u *UserSubscriptionUpsertBulk) ClearWeeklyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearWeeklyWindowStart()
})
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (u *UserSubscriptionUpsertBulk) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetMonthlyWindowStart(v)
})
}
// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateMonthlyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateMonthlyWindowStart()
})
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (u *UserSubscriptionUpsertBulk) ClearMonthlyWindowStart() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearMonthlyWindowStart()
})
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) SetDailyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetDailyUsageUsd(v)
})
}
// AddDailyUsageUsd adds v to the "daily_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) AddDailyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddDailyUsageUsd(v)
})
}
// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateDailyUsageUsd() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateDailyUsageUsd()
})
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetWeeklyUsageUsd(v)
})
}
// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddWeeklyUsageUsd(v)
})
}
// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateWeeklyUsageUsd() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateWeeklyUsageUsd()
})
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetMonthlyUsageUsd(v)
})
}
// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field.
func (u *UserSubscriptionUpsertBulk) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.AddMonthlyUsageUsd(v)
})
}
// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateMonthlyUsageUsd() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateMonthlyUsageUsd()
})
}
// SetAssignedBy sets the "assigned_by" field.
func (u *UserSubscriptionUpsertBulk) SetAssignedBy(v int64) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetAssignedBy(v)
})
}
// UpdateAssignedBy sets the "assigned_by" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateAssignedBy() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateAssignedBy()
})
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (u *UserSubscriptionUpsertBulk) ClearAssignedBy() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearAssignedBy()
})
}
// SetAssignedAt sets the "assigned_at" field.
func (u *UserSubscriptionUpsertBulk) SetAssignedAt(v time.Time) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetAssignedAt(v)
})
}
// UpdateAssignedAt sets the "assigned_at" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateAssignedAt() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateAssignedAt()
})
}
// SetNotes sets the "notes" field.
func (u *UserSubscriptionUpsertBulk) SetNotes(v string) *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.SetNotes(v)
})
}
// UpdateNotes sets the "notes" field to the value that was provided on create.
func (u *UserSubscriptionUpsertBulk) UpdateNotes() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.UpdateNotes()
})
}
// ClearNotes clears the value of the "notes" field.
func (u *UserSubscriptionUpsertBulk) ClearNotes() *UserSubscriptionUpsertBulk {
return u.Update(func(s *UserSubscriptionUpsert) {
s.ClearNotes()
})
}
// Exec executes the query.
func (u *UserSubscriptionUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserSubscriptionCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for UserSubscriptionCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *UserSubscriptionUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionDelete is the builder for deleting a UserSubscription entity.
type UserSubscriptionDelete struct {
config
hooks []Hook
mutation *UserSubscriptionMutation
}
// Where appends a list predicates to the UserSubscriptionDelete builder.
func (_d *UserSubscriptionDelete) Where(ps ...predicate.UserSubscription) *UserSubscriptionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserSubscriptionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserSubscriptionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserSubscriptionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(usersubscription.Table, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserSubscriptionDeleteOne is the builder for deleting a single UserSubscription entity.
type UserSubscriptionDeleteOne struct {
_d *UserSubscriptionDelete
}
// Where appends a list predicates to the UserSubscriptionDelete builder.
func (_d *UserSubscriptionDeleteOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserSubscriptionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{usersubscription.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserSubscriptionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionQuery is the builder for querying UserSubscription entities.
type UserSubscriptionQuery struct {
config
ctx *QueryContext
order []usersubscription.OrderOption
inters []Interceptor
predicates []predicate.UserSubscription
withUser *UserQuery
withGroup *GroupQuery
withAssignedByUser *UserQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserSubscriptionQuery builder.
func (_q *UserSubscriptionQuery) Where(ps ...predicate.UserSubscription) *UserSubscriptionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserSubscriptionQuery) Limit(limit int) *UserSubscriptionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserSubscriptionQuery) Offset(offset int) *UserSubscriptionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserSubscriptionQuery) Unique(unique bool) *UserSubscriptionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserSubscriptionQuery) Order(o ...usersubscription.OrderOption) *UserSubscriptionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserSubscriptionQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.UserTable, usersubscription.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryGroup chains the current query on the "group" edge.
func (_q *UserSubscriptionQuery) QueryGroup() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.GroupTable, usersubscription.GroupColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAssignedByUser chains the current query on the "assigned_by_user" edge.
func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, usersubscription.AssignedByUserTable, usersubscription.AssignedByUserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserSubscription entity from the query.
// Returns a *NotFoundError when no UserSubscription was found.
func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{usersubscription.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserSubscriptionQuery) FirstX(ctx context.Context) *UserSubscription {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserSubscription ID from the query.
// Returns a *NotFoundError when no UserSubscription ID was found.
func (_q *UserSubscriptionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{usersubscription.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserSubscriptionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserSubscription entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserSubscription entity is found.
// Returns a *NotFoundError when no UserSubscription entities are found.
func (_q *UserSubscriptionQuery) Only(ctx context.Context) (*UserSubscription, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{usersubscription.Label}
default:
return nil, &NotSingularError{usersubscription.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserSubscriptionQuery) OnlyX(ctx context.Context) *UserSubscription {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserSubscription ID in the query.
// Returns a *NotSingularError when more than one UserSubscription ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserSubscriptionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{usersubscription.Label}
default:
err = &NotSingularError{usersubscription.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserSubscriptionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserSubscriptions.
func (_q *UserSubscriptionQuery) All(ctx context.Context) ([]*UserSubscription, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserSubscription, *UserSubscriptionQuery]()
return withInterceptors[[]*UserSubscription](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserSubscriptionQuery) AllX(ctx context.Context) []*UserSubscription {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserSubscription IDs.
func (_q *UserSubscriptionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(usersubscription.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserSubscriptionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserSubscriptionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserSubscriptionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserSubscriptionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserSubscriptionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserSubscriptionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserSubscriptionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery {
if _q == nil {
return nil
}
return &UserSubscriptionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]usersubscription.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserSubscription{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withGroup: _q.withGroup.Clone(),
withAssignedByUser: _q.withAssignedByUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithUser(opts ...func(*UserQuery)) *UserSubscriptionQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithGroup tells the query-builder to eager-load the nodes that are connected to
// the "group" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithGroup(opts ...func(*GroupQuery)) *UserSubscriptionQuery {
query := (&GroupClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withGroup = query
return _q
}
// WithAssignedByUser tells the query-builder to eager-load the nodes that are connected to
// the "assigned_by_user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *UserSubscriptionQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAssignedByUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserSubscription.Query().
// GroupBy(usersubscription.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserSubscriptionQuery) GroupBy(field string, fields ...string) *UserSubscriptionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserSubscriptionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = usersubscription.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserSubscription.Query().
// Select(usersubscription.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserSubscriptionQuery) Select(fields ...string) *UserSubscriptionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserSubscriptionSelect{UserSubscriptionQuery: _q}
sbuild.label = usersubscription.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserSubscriptionSelect configured with the given aggregations.
func (_q *UserSubscriptionQuery) Aggregate(fns ...AggregateFunc) *UserSubscriptionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserSubscriptionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !usersubscription.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserSubscription, error) {
var (
nodes = []*UserSubscription{}
_spec = _q.querySpec()
loadedTypes = [3]bool{
_q.withUser != nil,
_q.withGroup != nil,
_q.withAssignedByUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserSubscription).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserSubscription{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserSubscription, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withGroup; query != nil {
if err := _q.loadGroup(ctx, query, nodes, nil,
func(n *UserSubscription, e *Group) { n.Edges.Group = e }); err != nil {
return nil, err
}
}
if query := _q.withAssignedByUser; query != nil {
if err := _q.loadAssignedByUser(ctx, query, nodes, nil,
func(n *UserSubscription, e *User) { n.Edges.AssignedByUser = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserSubscriptionQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *Group)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
fk := nodes[i].GroupID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(group.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query *UserQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserSubscription)
for i := range nodes {
if nodes[i].AssignedBy == nil {
continue
}
fk := *nodes[i].AssignedBy
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "assigned_by" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserSubscriptionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usersubscription.FieldID)
for i := range fields {
if fields[i] != usersubscription.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldUserID)
}
if _q.withGroup != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldGroupID)
}
if _q.withAssignedByUser != nil {
_spec.Node.AddColumnOnce(usersubscription.FieldAssignedBy)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(usersubscription.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = usersubscription.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
build *UserSubscriptionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserSubscriptionGroupBy) Aggregate(fns ...AggregateFunc) *UserSubscriptionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserSubscriptionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserSubscriptionQuery, *UserSubscriptionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserSubscriptionGroupBy) sqlScan(ctx context.Context, root *UserSubscriptionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserSubscriptionSelect is the builder for selecting fields of UserSubscription entities.
type UserSubscriptionSelect struct {
*UserSubscriptionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserSubscriptionSelect) Aggregate(fns ...AggregateFunc) *UserSubscriptionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserSubscriptionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserSubscriptionQuery, *UserSubscriptionSelect](ctx, _s.UserSubscriptionQuery, _s, _s.inters, v)
}
func (_s *UserSubscriptionSelect) sqlScan(ctx context.Context, root *UserSubscriptionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
// UserSubscriptionUpdate is the builder for updating UserSubscription entities.
type UserSubscriptionUpdate struct {
config
hooks []Hook
mutation *UserSubscriptionMutation
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdate) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableUserID(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UserSubscriptionUpdate) SetGroupID(v int64) *UserSubscriptionUpdate {
_u.mutation.SetGroupID(v)
return _u
}
// SetNillableGroupID sets the "group_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableGroupID(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetGroupID(*v)
}
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *UserSubscriptionUpdate) SetStartsAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableStartsAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *UserSubscriptionUpdate) SetExpiresAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableExpiresAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *UserSubscriptionUpdate) SetStatus(v string) *UserSubscriptionUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableStatus(v *string) *UserSubscriptionUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserSubscriptionUpdate) SetDailyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDailyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserSubscriptionUpdate) ClearDailyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserSubscriptionUpdate) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserSubscriptionUpdate) ClearWeeklyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserSubscriptionUpdate) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserSubscriptionUpdate) ClearMonthlyWindowStart() *UserSubscriptionUpdate {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetDailyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableDailyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddDailyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdate) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserSubscriptionUpdate {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdate) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpdate {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetAssignedBy sets the "assigned_by" field.
func (_u *UserSubscriptionUpdate) SetAssignedBy(v int64) *UserSubscriptionUpdate {
_u.mutation.SetAssignedBy(v)
return _u
}
// SetNillableAssignedBy sets the "assigned_by" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedBy(v *int64) *UserSubscriptionUpdate {
if v != nil {
_u.SetAssignedBy(*v)
}
return _u
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (_u *UserSubscriptionUpdate) ClearAssignedBy() *UserSubscriptionUpdate {
_u.mutation.ClearAssignedBy()
return _u
}
// SetAssignedAt sets the "assigned_at" field.
func (_u *UserSubscriptionUpdate) SetAssignedAt(v time.Time) *UserSubscriptionUpdate {
_u.mutation.SetAssignedAt(v)
return _u
}
// SetNillableAssignedAt sets the "assigned_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedAt(v *time.Time) *UserSubscriptionUpdate {
if v != nil {
_u.SetAssignedAt(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserSubscriptionUpdate) SetNotes(v string) *UserSubscriptionUpdate {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableNotes(v *string) *UserSubscriptionUpdate {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *UserSubscriptionUpdate) ClearNotes() *UserSubscriptionUpdate {
_u.mutation.ClearNotes()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserSubscriptionUpdate) SetUser(v *User) *UserSubscriptionUpdate {
return _u.SetUserID(v.ID)
}
// SetGroup sets the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdate) SetGroup(v *Group) *UserSubscriptionUpdate {
return _u.SetGroupID(v.ID)
}
// SetAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID.
func (_u *UserSubscriptionUpdate) SetAssignedByUserID(id int64) *UserSubscriptionUpdate {
_u.mutation.SetAssignedByUserID(id)
return _u
}
// SetNillableAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID if the given value is not nil.
func (_u *UserSubscriptionUpdate) SetNillableAssignedByUserID(id *int64) *UserSubscriptionUpdate {
if id != nil {
_u = _u.SetAssignedByUserID(*id)
}
return _u
}
// SetAssignedByUser sets the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUpdate {
return _u.SetAssignedByUserID(v.ID)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserSubscriptionUpdate) ClearUser() *UserSubscriptionUpdate {
_u.mutation.ClearUser()
return _u
}
// ClearGroup clears the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdate) ClearGroup() *UserSubscriptionUpdate {
_u.mutation.ClearGroup()
return _u
}
// ClearAssignedByUser clears the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate {
_u.mutation.ClearAssignedByUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserSubscriptionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserSubscriptionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserSubscriptionUpdate) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usersubscription.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UserSubscription.status": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.user"`)
}
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.group"`)
}
return nil
}
func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(usersubscription.FieldExpiresAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usersubscription.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(usersubscription.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(usersubscription.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(usersubscription.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldMonthlyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AssignedAt(); ok {
_spec.SetField(usersubscription.FieldAssignedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(usersubscription.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(usersubscription.FieldNotes, field.TypeString)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.GroupCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AssignedByUserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AssignedByUserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserSubscriptionUpdateOne is the builder for updating a single UserSubscription entity.
type UserSubscriptionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserSubscriptionMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableUserID(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UserSubscriptionUpdateOne) SetGroupID(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetGroupID(v)
return _u
}
// SetNillableGroupID sets the "group_id" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableGroupID(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetGroupID(*v)
}
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *UserSubscriptionUpdateOne) SetStartsAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableStartsAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *UserSubscriptionUpdateOne) SetExpiresAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableExpiresAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *UserSubscriptionUpdateOne) SetStatus(v string) *UserSubscriptionUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableStatus(v *string) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetDailyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearDailyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetWeeklyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearWeeklyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserSubscriptionUpdateOne) SetMonthlyWindowStart(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserSubscriptionUpdateOne) ClearMonthlyWindowStart() *UserSubscriptionUpdateOne {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetDailyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddDailyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetWeeklyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddWeeklyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) SetMonthlyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserSubscriptionUpdateOne) AddMonthlyUsageUsd(v float64) *UserSubscriptionUpdateOne {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetAssignedBy sets the "assigned_by" field.
func (_u *UserSubscriptionUpdateOne) SetAssignedBy(v int64) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedBy(v)
return _u
}
// SetNillableAssignedBy sets the "assigned_by" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedBy(v *int64) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetAssignedBy(*v)
}
return _u
}
// ClearAssignedBy clears the value of the "assigned_by" field.
func (_u *UserSubscriptionUpdateOne) ClearAssignedBy() *UserSubscriptionUpdateOne {
_u.mutation.ClearAssignedBy()
return _u
}
// SetAssignedAt sets the "assigned_at" field.
func (_u *UserSubscriptionUpdateOne) SetAssignedAt(v time.Time) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedAt(v)
return _u
}
// SetNillableAssignedAt sets the "assigned_at" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedAt(v *time.Time) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetAssignedAt(*v)
}
return _u
}
// SetNotes sets the "notes" field.
func (_u *UserSubscriptionUpdateOne) SetNotes(v string) *UserSubscriptionUpdateOne {
_u.mutation.SetNotes(v)
return _u
}
// SetNillableNotes sets the "notes" field if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableNotes(v *string) *UserSubscriptionUpdateOne {
if v != nil {
_u.SetNotes(*v)
}
return _u
}
// ClearNotes clears the value of the "notes" field.
func (_u *UserSubscriptionUpdateOne) ClearNotes() *UserSubscriptionUpdateOne {
_u.mutation.ClearNotes()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) SetUser(v *User) *UserSubscriptionUpdateOne {
return _u.SetUserID(v.ID)
}
// SetGroup sets the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdateOne) SetGroup(v *Group) *UserSubscriptionUpdateOne {
return _u.SetGroupID(v.ID)
}
// SetAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID.
func (_u *UserSubscriptionUpdateOne) SetAssignedByUserID(id int64) *UserSubscriptionUpdateOne {
_u.mutation.SetAssignedByUserID(id)
return _u
}
// SetNillableAssignedByUserID sets the "assigned_by_user" edge to the User entity by ID if the given value is not nil.
func (_u *UserSubscriptionUpdateOne) SetNillableAssignedByUserID(id *int64) *UserSubscriptionUpdateOne {
if id != nil {
_u = _u.SetAssignedByUserID(*id)
}
return _u
}
// SetAssignedByUser sets the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptionUpdateOne {
return _u.SetAssignedByUserID(v.ID)
}
// Mutation returns the UserSubscriptionMutation object of the builder.
func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) ClearUser() *UserSubscriptionUpdateOne {
_u.mutation.ClearUser()
return _u
}
// ClearGroup clears the "group" edge to the Group entity.
func (_u *UserSubscriptionUpdateOne) ClearGroup() *UserSubscriptionUpdateOne {
_u.mutation.ClearGroup()
return _u
}
// ClearAssignedByUser clears the "assigned_by_user" edge to the User entity.
func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpdateOne {
_u.mutation.ClearAssignedByUser()
return _u
}
// Where appends a list predicates to the UserSubscriptionUpdate builder.
func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *UserSubscriptionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserSubscription entity.
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserSubscriptionUpdateOne) SaveX(ctx context.Context) *UserSubscription {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserSubscriptionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserSubscriptionUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usersubscription.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserSubscriptionUpdateOne) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usersubscription.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UserSubscription.status": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.user"`)
}
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserSubscription.group"`)
}
return nil
}
func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSubscription, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usersubscription.Table, usersubscription.Columns, sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserSubscription.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usersubscription.FieldID)
for _, f := range fields {
if !usersubscription.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != usersubscription.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(usersubscription.FieldExpiresAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usersubscription.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(usersubscription.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(usersubscription.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(usersubscription.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(usersubscription.FieldMonthlyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(usersubscription.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AssignedAt(); ok {
_spec.SetField(usersubscription.FieldAssignedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(usersubscription.FieldNotes, field.TypeString, value)
}
if _u.mutation.NotesCleared() {
_spec.ClearField(usersubscription.FieldNotes, field.TypeString)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.UserTable,
Columns: []string{usersubscription.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.GroupCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.GroupTable,
Columns: []string{usersubscription.GroupColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AssignedByUserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AssignedByUserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: usersubscription.AssignedByUserTable,
Columns: []string{usersubscription.AssignedByUserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserSubscription{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usersubscription.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}
......@@ -5,6 +5,7 @@ go 1.24.0
toolchain go1.24.11
require (
entgo.io/ent v0.14.5
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
......@@ -23,17 +24,18 @@ require (
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
)
require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
......@@ -58,20 +60,20 @@ require (
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.9.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.4 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
......@@ -82,7 +84,9 @@ require (
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
......@@ -94,6 +98,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
......@@ -103,6 +108,7 @@ require (
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
......@@ -111,6 +117,7 @@ require (
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/cobra v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
......@@ -121,6 +128,8 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zclconf/go-cty v1.14.4 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
......@@ -137,8 +146,11 @@ require (
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
gorm.io/datatypes v1.2.7 // indirect
gorm.io/driver/mysql v1.5.6 // indirect
gorm.io/gorm v1.30.0 // indirect
)
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 h1:E0wvcUXTkgyN4wy4LGtNzMNGMytJN8afmIWXJVMi4cc=
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9/go.mod h1:Oe1xWPuu5q9LzyrWfbZmEZxFYeu4BHTyzfjeW2aZp/w=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
entgo.io/ent v0.14.5 h1:Rj2WOYJtCkWyFo6a+5wB3EfBRP0rnx1fMk6gGA0UUe4=
entgo.io/ent v0.14.5/go.mod h1:zTzLmWtPvGpmSwtkaayM2cm5m819NdM7z7tYPq3vN0U=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
......@@ -34,6 +46,7 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
......@@ -73,6 +86,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNPXu/4=
github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
......@@ -84,14 +99,12 @@ github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QX
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
......@@ -109,10 +122,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLW
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
......@@ -136,6 +153,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
......@@ -149,12 +168,15 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
......@@ -180,6 +202,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
......@@ -203,12 +227,17 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
......@@ -221,6 +250,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
......@@ -269,6 +300,10 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=
github.com/zclconf/go-cty v1.14.4/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE=
github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0=
github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs=
github.com/zeromicro/go-zero v1.9.4 h1:aRLFoISqAYijABtkbliQC5SsI5TbizJpQvoHc9xup8k=
github.com/zeromicro/go-zero v1.9.4/go.mod h1:a17JOTch25SWxBcUgJZYps60hygK3pIYdw7nGwlcS38=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
......@@ -329,6 +364,10 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
......@@ -347,19 +386,13 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/datatypes v1.2.7 h1:ww9GAhF1aGXZY3EB3cJPJ7//JiuQo7DlQA7NNlVaTdk=
gorm.io/datatypes v1.2.7/go.mod h1:M2iO+6S3hhi4nAyYe444Pcb0dcIiOMJ7QHaUXxyiNZY=
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
......@@ -215,8 +215,10 @@ func setDefaults() {
viper.SetDefault("jwt.expire_hour", 24)
// Default
viper.SetDefault("default.admin_email", "admin@sub2api.com")
viper.SetDefault("default.admin_password", "admin123")
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
viper.SetDefault("default.admin_email", "")
viper.SetDefault("default.admin_password", "")
viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-")
......
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := repository.AutoMigrate(db, cfg.RunMode); err != nil {
return nil, err
}
return db, nil
}
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/migrations"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
)
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
//
// 该函数执行以下操作:
// 1. 初始化全局时区设置,确保时间处理一致性
// 2. 建立 PostgreSQL 数据库连接
// 3. 自动执行数据库迁移,确保 schema 与代码同步
// 4. 创建并返回 Ent 客户端实例
//
// 重要提示:调用者必须负责关闭返回的 ent.Client(关闭时会自动关闭底层的 driver/db)。
//
// 参数:
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
//
// 返回:
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
// - error: 初始化过程中的错误
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
// 这对于跨时区部署和日志时间戳的一致性至关重要。
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, nil, err
}
// 构建包含时区信息的数据库连接字符串 (DSN)。
// 时区信息会传递给 PostgreSQL,确保数据库层面的时间处理正确。
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
drv, err := entsql.Open(dialect.Postgres, dsn)
if err != nil {
return nil, nil, err
}
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
return nil, nil, err
}
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
return client, drv.DB(), nil
}
package infrastructure
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/migrations"
)
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
// 该表用于跟踪已应用的迁移文件及其校验和。
// - filename: 迁移文件名,作为主键唯一标识每个迁移
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
// - applied_at: 迁移应用时间戳
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
checksum TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
// 该函数可以在每次应用启动时安全调用:
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
// - 如果迁移文件内容被修改(checksum 不匹配),会返回错误
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
//
// 参数:
// - ctx: 上下文,用于超时控制和取消
// - db: 数据库连接
//
// 返回:
// - error: 迁移过程中的任何错误
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if db == nil {
return errors.New("nil sql db")
}
return applyMigrationsFS(ctx, db, migrations.FS)
}
// applyMigrationsFS 是迁移执行的核心实现。
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
//
// 迁移执行流程:
// 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移
// 2. 确保 schema_migrations 表存在
// 3. 按文件名排序读取所有 .sql 文件
// 4. 对于每个迁移文件:
// - 计算文件内容的 SHA256 校验和
// - 检查该迁移是否已应用(通过 filename 查询)
// - 如果已应用,验证校验和是否匹配
// - 如果未应用,在事务中执行迁移并记录
// 5. 释放 Advisory Lock
//
// 参数:
// - ctx: 上下文
// - db: 数据库连接
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS)
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if db == nil {
return errors.New("nil sql db")
}
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
if err := pgAdvisoryLock(ctx, db); err != nil {
return err
}
defer func() {
// 无论迁移是否成功,都要释放锁。
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
_ = pgAdvisoryUnlock(context.Background(), db)
}()
// 创建迁移记录表(如果不存在)。
// 该表记录所有已应用的迁移及其校验和。
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return fmt.Errorf("list migrations: %w", err)
}
sort.Strings(files) // 确保按文件名顺序执行迁移
for _, name := range files {
// 读取迁移文件内容
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return fmt.Errorf("read migration %s: %w", name, err)
}
content := strings.TrimSpace(string(contentBytes))
if content == "" {
continue // 跳过空文件
}
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
sum := sha256.Sum256([]byte(content))
checksum := hex.EncodeToString(sum[:])
// 检查该迁移是否已经应用
var existing string
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
}
continue // 迁移已应用且校验和匹配,跳过
}
if !errors.Is(rowErr, sql.ErrNoRows) {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
}
// 执行迁移 SQL
if _, err := tx.ExecContext(ctx, content); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %s: %w", name, err)
}
// 记录迁移已完成,保存文件名和校验和
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %s: %w", name, err)
}
// 提交事务
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return fmt.Errorf("commit migration %s: %w", name, err)
}
}
return nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
ticker := time.NewTicker(migrationsLockRetryInterval)
defer ticker.Stop()
for {
var locked bool
if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
return fmt.Errorf("acquire migrations lock: %w", err)
}
if locked {
return nil
}
select {
case <-ctx.Done():
return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
case <-ticker.C:
}
}
}
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("release migrations lock: %w", err)
}
return nil
}
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 提供基础设施层的依赖
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 提供 Redis 客户端
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}
// Package repository 实现数据访问层(Repository Pattern)。
//
// 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。
// 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。
//
// 主要特性:
// - 使用 Ent ORM 进行类型安全的数据库操作
// - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL
// - 提供统一的错误翻译机制,将数据库错误转换为业务错误
// - 支持软删除,所有查询自动过滤已删除记录
package repository
import (
"context"
"errors"
"database/sql"
"encoding/json"
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
)
// accountRepository 实现 service.AccountRepository 接口。
// 提供 AI API 账户的完整数据访问功能。
//
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
type accountRepository struct {
db *gorm.DB
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
}
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
return &accountRepository{client: client, sql: sqlq}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
if account == nil {
return nil
}
builder := r.client.Account.Create().
SetName(account.Name).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
SetExtra(normalizeJSONMap(account.Extra)).
SetConcurrency(account.Concurrency).
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
}
if account.RateLimitResetAt != nil {
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
}
if account.OverloadUntil != nil {
builder.SetOverloadUntil(*account.OverloadUntil)
}
if account.SessionWindowStart != nil {
builder.SetSessionWindowStart(*account.SessionWindowStart)
}
if account.SessionWindowEnd != nil {
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
}
if account.SessionWindowStatus != "" {
builder.SetSessionWindowStatus(account.SessionWindowStatus)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
account.ID = created.ID
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
return nil
}
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
m, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return accountModelToService(&m), nil
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, service.ErrAccountNotFound
}
return &accounts[0], nil
}
// ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID,此方法性能更优,因为:
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
// - 不加载完整的账号实体及其关联数据(Groups、Proxy 等)
// - 适用于删除前的存在性检查等只需判断有无的场景
func (r *accountRepository) ExistsByID(ctx context.Context, id int64) (bool, error) {
exists, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Exist(ctx)
if err != nil {
return false, err
}
return exists, nil
}
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
......@@ -44,31 +141,101 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return nil, nil
}
var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
// 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。
m, err := r.client.Account.Query().
Where(func(s *entsql.Selector) {
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
}).
Only(ctx)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return accountModelToService(&m), nil
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, nil
}
return &accounts[0], nil
}
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
if account == nil {
return nil
}
return err
builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name).
SetPlatform(account.Platform).
SetType(account.Type).
SetCredentials(normalizeJSONMap(account.Credentials)).
SetExtra(normalizeJSONMap(account.Extra)).
SetConcurrency(account.Concurrency).
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
builder.ClearProxyID()
}
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
} else {
builder.ClearLastUsedAt()
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
} else {
builder.ClearRateLimitedAt()
}
if account.RateLimitResetAt != nil {
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
} else {
builder.ClearRateLimitResetAt()
}
if account.OverloadUntil != nil {
builder.SetOverloadUntil(*account.OverloadUntil)
} else {
builder.ClearOverloadUntil()
}
if account.SessionWindowStart != nil {
builder.SetSessionWindowStart(*account.SessionWindowStart)
} else {
builder.ClearSessionWindowStart()
}
if account.SessionWindowEnd != nil {
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
} else {
builder.ClearSessionWindowEnd()
}
if account.SessionWindowStatus != "" {
builder.SetSessionWindowStatus(account.SessionWindowStatus)
} else {
builder.ClearSessionWindowStatus()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
......@@ -76,99 +243,84 @@ func (r *accountRepository) List(ctx context.Context, params pagination.Paginati
}
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel
var total int64
db := r.db.WithContext(ctx).Model(&accountModel{})
q := r.client.Account.Query()
if platform != "" {
db = db.Where("platform = ?", platform)
q = q.Where(dbaccount.PlatformEQ(platform))
}
if accountType != "" {
db = db.Where("type = ?", accountType)
q = q.Where(dbaccount.TypeEQ(accountType))
}
if status != "" {
db = db.Where("status = ?", status)
q = q.Where(dbaccount.StatusEQ(status))
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
q = q.Where(dbaccount.NameContainsFold(search))
}
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
accounts, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbaccount.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
outAccounts, err := r.accountsToService(ctx, accounts)
if err != nil {
return nil, nil, err
}
return outAccounts, paginationResultFromTotal(total, params), nil
return outAccounts, paginationResultFromTotal(int64(total), params), nil
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
})
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return accounts, nil
}
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("status = ?", service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
accounts, err := r.client.Account.Query().
Where(dbaccount.StatusEQ(service.StatusActive)).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetLastUsedAt(now).
Save(ctx)
return err
}
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
......@@ -176,63 +328,72 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
return nil
}
var caseSql = "UPDATE accounts SET last_used_at = CASE id"
var args []any
var ids []int64
ids := make([]int64, 0, len(updates))
args := make([]any, 0, len(updates)*2+1)
caseSQL := "UPDATE accounts SET last_used_at = CASE id"
idx := 1
for id, ts := range updates {
caseSql += " WHEN ? THEN CAST(? AS TIMESTAMP)"
caseSQL += " WHEN $" + itoa(idx) + " THEN $" + itoa(idx+1)
args = append(args, id, ts)
ids = append(ids, id)
idx += 2
}
caseSql += " END WHERE id IN ? AND deleted_at IS NULL"
args = append(args, ids)
caseSQL += " END, updated_at = NOW() WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
args = append(args, pq.Array(ids))
return r.db.WithContext(ctx).Exec(caseSql, args...).Error
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
return err
}
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"status": service.StatusError,
"error_message": errorMsg,
}).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusError).
SetErrorMessage(errorMsg).
Save(ctx)
return err
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
}
return r.db.WithContext(ctx).Create(ag).Error
_, err := r.client.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(priority).
Save(ctx)
return err
}
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&accountGroupModel{}).Error
_, err := r.client.AccountGroup.Delete().
Where(
dbaccountgroup.AccountIDEQ(accountID),
dbaccountgroup.GroupIDEQ(groupID),
).
Exec(ctx)
return err
}
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
groups, err := r.client.Group.Query().
Where(
dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID)),
).
All(ctx)
if err != nil {
return nil, err
}
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
outGroups = append(outGroups, *groupEntityToService(groups[i]))
}
return outGroups, nil
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err
}
......@@ -240,192 +401,153 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
return nil
}
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
builders = append(builders, r.client.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(i+1),
)
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
}
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
accounts, err := r.client.Account.Query().
Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
schedulable: true,
})
}
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
// 单平台查询复用多平台逻辑,保持过滤条件与排序策略一致。
return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
schedulable: true,
platforms: []string{platform},
})
}
func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
// 仅返回可调度的活跃账号,并过滤处于过载/限流窗口的账号。
// 代理与分组信息统一在 accountsToService 中批量加载,避免 N+1 查询。
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform IN ?", platforms).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform IN ?", platforms).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
// 复用按分组查询逻辑,保证分组优先级 + 账号优先级的排序与筛选一致。
return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
schedulable: true,
platforms: platforms,
})
}
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetRateLimitedAt(now).
SetRateLimitResetAt(resetAt).
Save(ctx)
return err
}
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetOverloadUntil(until).
Save(ctx)
return err
}
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
"overload_until": nil,
}).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
ClearRateLimitedAt().
ClearRateLimitResetAt().
ClearOverloadUntil().
Save(ctx)
return err
}
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
}
builder := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetSessionWindowStatus(status)
if start != nil {
updates["session_window_start"] = start
builder.SetSessionWindowStart(*start)
}
if end != nil {
updates["session_window_end"] = end
builder.SetSessionWindowEnd(*end)
}
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
_, err := builder.Save(ctx)
return err
}
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetSchedulable(schedulable).
Save(ctx)
return err
}
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
......@@ -433,20 +555,24 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil
}
var account accountModel
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
accountExtra, err := r.client.Account.Query().
Where(dbaccount.IDEQ(id)).
Select(dbaccount.FieldExtra).
Only(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
if account.Extra == nil {
account.Extra = datatypes.JSONMap{}
}
extra := normalizeJSONMap(accountExtra.Extra)
for k, v := range updates {
account.Extra[k] = v
extra[k] = v
}
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("extra", account.Extra).Error
_, err = r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
......@@ -454,129 +580,261 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
return 0, nil
}
updateMap := map[string]any{}
setClauses := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if updates.Name != nil {
updateMap["name"] = *updates.Name
setClauses = append(setClauses, "name = $"+itoa(idx))
args = append(args, *updates.Name)
idx++
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID)
idx++
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
setClauses = append(setClauses, "concurrency = $"+itoa(idx))
args = append(args, *updates.Concurrency)
idx++
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
setClauses = append(setClauses, "priority = $"+itoa(idx))
args = append(args, *updates.Priority)
idx++
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
idx++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
payload, err := json.Marshal(updates.Credentials)
if err != nil {
return 0, err
}
setClauses = append(setClauses, "credentials = COALESCE(credentials, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
args = append(args, payload)
idx++
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
payload, err := json.Marshal(updates.Extra)
if err != nil {
return 0, err
}
setClauses = append(setClauses, "extra = COALESCE(extra, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
args = append(args, payload)
idx++
}
if len(updateMap) == 0 {
if len(setClauses) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&accountModel{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
setClauses = append(setClauses, "updated_at = NOW()")
query := "UPDATE accounts SET " + joinClauses(setClauses, ", ") + " WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
args = append(args, pq.Array(ids))
result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
return result.RowsAffected, result.Error
type accountGroupQueryOptions struct {
status string
schedulable bool
platforms []string // 允许的多个平台,空切片表示不进行平台过滤
}
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) {
q := r.client.AccountGroup.Query().
Where(dbaccountgroup.GroupIDEQ(groupID))
// 通过 account_groups 中间表查询账号,并按需叠加状态/平台/调度能力过滤。
preds := make([]dbpredicate.Account, 0, 6)
preds = append(preds, dbaccount.DeletedAtIsNil())
if opts.status != "" {
preds = append(preds, dbaccount.StatusEQ(opts.status))
}
if len(opts.platforms) > 0 {
preds = append(preds, dbaccount.PlatformIn(opts.platforms...))
}
if opts.schedulable {
now := time.Now()
preds = append(preds,
dbaccount.SchedulableEQ(true),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
}
if len(preds) > 0 {
q = q.Where(dbaccountgroup.HasAccountWith(preds...))
}
Schedulable bool `gorm:"default:true;not null"`
groups, err := q.
Order(
dbaccountgroup.ByPriority(),
dbaccountgroup.ByAccountField(dbaccount.FieldPriority),
).
WithAccount().
All(ctx)
if err != nil {
return nil, err
}
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
orderedIDs := make([]int64, 0, len(groups))
accountMap := make(map[int64]*dbent.Account, len(groups))
for _, ag := range groups {
if ag.Edges.Account == nil {
continue
}
if _, exists := accountMap[ag.AccountID]; exists {
continue
}
accountMap[ag.AccountID] = ag.Edges.Account
orderedIDs = append(orderedIDs, ag.AccountID)
}
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
accounts := make([]*dbent.Account, 0, len(orderedIDs))
for _, id := range orderedIDs {
if acc, ok := accountMap[id]; ok {
accounts = append(accounts, acc)
}
}
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
return r.accountsToService(ctx, accounts)
}
func (accountModel) TableName() string { return "accounts" }
func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) {
if len(accounts) == 0 {
return []service.Account{}, nil
}
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
accountIDs := make([]int64, 0, len(accounts))
proxyIDs := make([]int64, 0, len(accounts))
for _, acc := range accounts {
accountIDs = append(accountIDs, acc.ID)
if acc.ProxyID != nil {
proxyIDs = append(proxyIDs, *acc.ProxyID)
}
}
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
proxyMap, err := r.loadProxies(ctx, proxyIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for _, acc := range accounts {
out := accountEntityToService(acc)
if out == nil {
continue
}
if acc.ProxyID != nil {
if proxy, ok := proxyMap[*acc.ProxyID]; ok {
out.Proxy = proxy
}
}
if groups, ok := groupsByAccount[acc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[acc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
outAccounts = append(outAccounts, *out)
}
return outAccounts, nil
}
func (accountGroupModel) TableName() string { return "account_groups" }
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
return proxyMap, nil
}
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
proxies, err := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...)).All(ctx)
if err != nil {
return nil, err
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
for _, p := range proxies {
proxyMap[p.ID] = proxyEntityToService(p)
}
return proxyMap, nil
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) {
groupsByAccount := make(map[int64][]*service.Group)
groupIDsByAccount := make(map[int64][]int64)
accountGroupsByAccount := make(map[int64][]service.AccountGroup)
if len(accountIDs) == 0 {
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
entries, err := r.client.AccountGroup.Query().
Where(dbaccountgroup.AccountIDIn(accountIDs...)).
WithGroup().
Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()).
All(ctx)
if err != nil {
return nil, nil, nil, err
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
for _, ag := range entries {
groupSvc := groupEntityToService(ag.Edges.Group)
agSvc := service.AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Group: groupSvc,
}
accountGroupsByAccount[ag.AccountID] = append(accountGroupsByAccount[ag.AccountID], agSvc)
groupIDsByAccount[ag.AccountID] = append(groupIDsByAccount[ag.AccountID], ag.GroupID)
if groupSvc != nil {
groupsByAccount[ag.AccountID] = append(groupsByAccount[ag.AccountID], groupSvc)
}
}
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
}
account := &service.Account{
func accountEntityToService(m *dbent.Account) *service.Account {
if m == nil {
return nil
}
return &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
......@@ -586,75 +844,39 @@ func accountModelToService(m *accountModel) *service.Account {
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
SessionWindowStatus: derefString(m.SessionWindowStatus),
}
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
func normalizeJSONMap(in map[string]any) map[string]any {
if in == nil {
return map[string]any{}
}
return in
}
func copyJSONMap(in map[string]any) map[string]any {
if in == nil {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return account
return out
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
func joinClauses(clauses []string, sep string) string {
if len(clauses) == 0 {
return ""
}
out := clauses[0]
for i := 1; i < len(clauses); i++ {
out += sep + clauses[i]
}
return out
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
func itoa(v int) string {
return strconv.Itoa(v)
}
......@@ -7,24 +7,25 @@ import (
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
client *dbent.Client
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db).(*accountRepository)
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
}
func TestAccountRepoSuite(t *testing.T) {
......@@ -61,7 +62,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
}
func (s *AccountRepoSuite) TestUpdate() {
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
......@@ -73,7 +74,7 @@ func (s *AccountRepoSuite) TestUpdate() {
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
......@@ -83,23 +84,23 @@ func (s *AccountRepoSuite) TestDelete() {
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
......@@ -110,7 +111,7 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
setup func(client *dbent.Client)
platform string
accType string
status string
......@@ -120,9 +121,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: service.PlatformOpenAI,
wantCount: 1,
......@@ -132,9 +133,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeApiKey,
wantCount: 1,
......@@ -144,9 +145,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
},
status: service.StatusDisabled,
wantCount: 1,
......@@ -156,9 +157,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
......@@ -171,11 +172,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db).(*accountRepository)
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background()
tt.setup(db)
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
......@@ -190,11 +192,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
......@@ -204,8 +206,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
......@@ -214,8 +216,8 @@ func (s *AccountRepoSuite) TestListActive() {
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
......@@ -226,14 +228,14 @@ func (s *AccountRepoSuite) TestListByPlatform() {
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &accountModel{
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
......@@ -257,9 +259,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
......@@ -279,9 +281,9 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
......@@ -294,14 +296,14 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
......@@ -312,17 +314,17 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
......@@ -339,8 +341,8 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
......@@ -349,11 +351,11 @@ func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
......@@ -362,7 +364,7 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
......@@ -374,7 +376,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
......@@ -386,7 +388,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
......@@ -399,7 +401,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
......@@ -416,7 +418,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
......@@ -429,7 +431,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
......@@ -442,7 +444,7 @@ func (s *AccountRepoSuite) TestSetError() {
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
......@@ -458,9 +460,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &accountModel{
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra",
Extra: datatypes.JSONMap{"a": "1"},
Extra: map[string]any{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
......@@ -471,12 +473,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
......@@ -488,9 +490,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &accountModel{
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-crs",
Extra: datatypes.JSONMap{"crs_account_id": crsID},
Extra: map[string]any{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
......@@ -514,8 +516,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
......@@ -531,13 +533,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-cred",
Credentials: datatypes.JSONMap{"existing": "value"},
Credentials: map[string]any{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: datatypes.JSONMap{"new_key": "new_value"},
Credentials: map[string]any{"new_key": "new_value"},
})
s.Require().NoError(err)
......@@ -547,13 +549,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-extra",
Extra: datatypes.JSONMap{"existing": "val"},
Extra: map[string]any{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: datatypes.JSONMap{"new_key": "new_val"},
Extra: map[string]any{"new_key": "new_val"},
})
s.Require().NoError(err)
......@@ -569,7 +571,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
......
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueTestValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "other-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, tx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u1))
u2 := &service.User{
Email: uniqueTestValue(t, "u2") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID},
}
require.NoError(t, repo.Create(ctx, u2))
u3 := &service.User{
Email: uniqueTestValue(t, "u3") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u3))
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
require.NoError(t, err)
require.Equal(t, int64(2), affected)
u1After, err := repo.GetByID(ctx, u1.ID)
require.NoError(t, err)
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
u2After, err := repo.GetByID(ctx, u2.ID)
require.NoError(t, err)
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-other")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",
GroupID: &targetGroup.ID,
Status: service.StatusActive,
}
require.NoError(t, apiKeyRepo.Create(ctx, key))
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
require.NoError(t, err)
// Deleted group should be hidden by default queries (soft-delete semantics).
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
require.ErrorIs(t, err, service.ErrGroupNotFound)
activeGroups, err := groupRepo.ListActive(ctx)
require.NoError(t, err)
for _, g := range activeGroups {
require.NotEqual(t, targetGroup.ID, g.ID)
}
// User.allowed_groups should no longer include the deleted group.
uAfter, err := userRepo.GetByID(ctx, u.ID)
require.NoError(t, err)
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}
......@@ -4,81 +4,175 @@ import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type apiKeyRepository struct {
db *gorm.DB
client *dbent.Client
}
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client}
}
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
created, err := r.client.ApiKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
if err == nil {
applyApiKeyModelToService(key, m)
key.ID = created.ID
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return apiKeyModelToService(&m), nil
return nil, err
}
return apiKeyEntityToService(m), nil
}
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID,此方法性能更优,因为:
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrApiKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return apiKeyModelToService(&m), nil
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.ApiKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
} else {
builder.ClearGroupID()
}
affected, err := builder.Save(ctx)
if err != nil {
return err
}
if affected == 0 {
// 更新影响行数为 0,说明记录不存在或已被软删除。
return service.ErrApiKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
key.UpdatedAt = now
return nil
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
return err
}
if exists {
return nil
}
return service.ErrApiKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(total, params), nil
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
......@@ -86,11 +180,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids := make([]int64, 0, len(apiKeyIDs))
err := r.db.WithContext(ctx).
Model(&apiKeyModel{}).
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
Pluck("id", &ids).Error
ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
return nil, err
}
......@@ -98,136 +190,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
return int64(count), err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(total, params), nil
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
q := r.activeQuery()
if userID > 0 {
db = db.Where("user_id = ?", userID)
q = q.Where(apikey.UserIDEQ(userID))
}
if keyword != "" {
searchPattern := "%" + keyword + "%"
db = db.Where("name ILIKE ?", searchPattern)
q = q.Where(apikey.NameContainsFold(keyword))
}
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
if err != nil {
return nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
return int64(n), err
}
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
out := &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
func groupEntityToService(g *dbent.Group) *service.Group {
if g == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
func derefString(s *string) string {
if s == nil {
return ""
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
return *s
}
......@@ -6,23 +6,24 @@ import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
client *dbent.Client
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
......@@ -32,7 +33,7 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{
UserID: user.ID,
......@@ -56,16 +57,17 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
})
}
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
......@@ -84,13 +86,14 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: service.StatusActive,
}))
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.Name = "Renamed"
key.Status = service.StatusDisabled
......@@ -106,14 +109,16 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
}))
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
......@@ -127,12 +132,14 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
......@@ -144,9 +151,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
......@@ -155,13 +162,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
......@@ -172,9 +175,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
......@@ -184,12 +187,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
......@@ -200,10 +203,9 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
......@@ -213,8 +215,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
......@@ -228,9 +230,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
......@@ -239,9 +241,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
......@@ -249,8 +251,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
......@@ -260,12 +262,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
......@@ -283,16 +285,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}))
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
......@@ -330,12 +326,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
......@@ -353,3 +345,41 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
s.T().Helper()
k := &service.ApiKey{
UserID: userID,
Key: key,
Name: name,
GroupID: groupID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
package repository
import (
"log"
"time"
"gorm.io/gorm"
)
// MaxExpiresAt is the maximum allowed expiration date for subscriptions (year 2099)
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
// runMode: "standard" or "simple" - determines whether to create default groups
func AutoMigrate(db *gorm.DB, runMode string) error {
err := db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
if err != nil {
return err
}
// 创建默认分组(简易模式支持)
if err := ensureDefaultGroups(db, runMode); err != nil {
return err
}
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
return fixInvalidExpiresAt(db)
}
// fixInvalidExpiresAt 修复 user_subscriptions 表中无效的过期时间
func fixInvalidExpiresAt(db *gorm.DB) error {
result := db.Model(&userSubscriptionModel{}).
Where("expires_at > ?", maxExpiresAt).
Update("expires_at", maxExpiresAt)
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
log.Printf("[AutoMigrate] Fixed %d subscriptions with invalid expires_at (year > 2099)", result.RowsAffected)
}
return nil
}
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
// runMode: "standard" 时跳过创建, "simple" 时创建/恢复默认分组
func ensureDefaultGroups(db *gorm.DB, runMode string) error {
// 标准版不创建默认分组
if runMode == "standard" {
return nil
}
defaultGroups := []struct {
name string
platform string
description string
}{
{
name: "anthropic-default",
platform: "anthropic",
description: "Default group for Anthropic accounts (Simple Mode)",
},
{
name: "openai-default",
platform: "openai",
description: "Default group for OpenAI accounts (Simple Mode)",
},
{
name: "gemini-default",
platform: "gemini",
description: "Default group for Gemini accounts (Simple Mode)",
},
}
for _, dg := range defaultGroups {
// 步骤1: 检查是否有软删除的记录
var softDeletedCount int64
if err := db.Unscoped().Model(&groupModel{}).
Where("name = ? AND deleted_at IS NOT NULL", dg.name).
Count(&softDeletedCount).Error; err != nil {
return err
}
if softDeletedCount > 0 {
// 恢复软删除的记录
if err := db.Unscoped().Model(&groupModel{}).
Where("name = ?", dg.name).
Update("deleted_at", nil).Error; err != nil {
log.Printf("[AutoMigrate] Failed to restore default group %s: %v", dg.name, err)
return err
}
log.Printf("[AutoMigrate] Restored default group: %s (platform: %s)", dg.name, dg.platform)
continue
}
// 步骤2: 检查是否有活跃记录
var activeCount int64
if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&activeCount).Error; err != nil {
return err
}
if activeCount == 0 {
// 创建新分组
group := &groupModel{
Name: dg.name,
Description: dg.description,
Platform: dg.platform,
RateMultiplier: 1.0,
IsExclusive: false,
Status: "active",
SubscriptionType: "standard",
}
if err := db.Create(group).Error; err != nil {
log.Printf("[AutoMigrate] Failed to create default group %s: %v", dg.name, err)
return err
}
log.Printf("[AutoMigrate] Created default group: %s (platform: %s)", dg.name, dg.platform)
}
}
return nil
}
package repository
import (
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
"github.com/lib/pq"
)
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound)
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows)。
//
// 参数:
// - err: 原始数据库错误
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
//
// 返回:
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
//
// 示例:
//
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
// Ent 使用自定义的 NotFoundError,而标准库使用 sql.ErrNoRows。
// 这里同时处理两种情况,保持业务错误映射一致。
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
return notFound.WithCause(err)
}
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
// 未匹配任何规则,返回原始错误
return err
}
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
//
// 支持多种检测方式:
// 1. PostgreSQL 特定错误码 23505(唯一约束冲突)
// 2. 错误消息中包含的通用关键词
//
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
// 优先检测 PostgreSQL 特定错误码(最精确)。
// 错误码 23505 对应 unique_violation。
// 参考:https://www.postgresql.org/docs/current/errcodes-appendix.html
var pgErr *pq.Error
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
// 回退到错误消息检测(兼容其他场景)。
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment