diff --git a/.gitignore b/.gitignore index 1a92ea3e641316b8bc2f88def04cee30ec1377d5..cf2bda9f108f6f99f60c94f8eeee11f0466e17fa 100644 --- a/.gitignore +++ b/.gitignore @@ -126,12 +126,9 @@ backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ vite.config.js -docs/* -!docs/PAYMENT.md -!docs/PAYMENT_CN.md +docs/ .serena/ .codex/ frontend/coverage/ aicodex output/ - diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1d39fa1e31a9c8b69f8b512cd2a504953f2cfa1d..3b474c4a98e88af391cb0ed6e86991474a93078b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService, emailService, emailCache) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go new file mode 100644 index 0000000000000000000000000000000000000000..5ccfcf19102d780646333f5f2da43d21ca8ce685 --- /dev/null +++ b/backend/ent/authidentity.go @@ -0,0 +1,266 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentity is the model entity for the AuthIdentity schema. +type AuthIdentity struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // VerifiedAt holds the value of the "verified_at" field. + VerifiedAt *time.Time `json:"verified_at,omitempty"` + // Issuer holds the value of the "issuer" field. + Issuer *string `json:"issuer,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityQuery when eager-loading is set. + Edges AuthIdentityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Channels holds the value of the channels edge. + Channels []*AuthIdentityChannel `json:"channels,omitempty"` + // AdoptionDecisions holds the value of the adoption_decisions edge. + AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// ChannelsOrErr returns the Channels value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) { + if e.loadedTypes[1] { + return e.Channels, nil + } + return nil, &NotLoadedError{edge: "channels"} +} + +// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) { + if e.loadedTypes[2] { + return e.AdoptionDecisions, nil + } + return nil, &NotLoadedError{edge: "adoption_decisions"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentity.FieldMetadata: + values[i] = new([]byte) + case authidentity.FieldID, authidentity.FieldUserID: + values[i] = new(sql.NullInt64) + case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer: + values[i] = new(sql.NullString) + case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentity fields. +func (_m *AuthIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentity.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case authidentity.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentity.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentity.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case authidentity.FieldVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field verified_at", values[i]) + } else if value.Valid { + _m.VerifiedAt = new(time.Time) + *_m.VerifiedAt = value.Time + } + case authidentity.FieldIssuer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field issuer", values[i]) + } else if value.Valid { + _m.Issuer = new(string) + *_m.Issuer = value.String + } + case authidentity.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryUser() *UserQuery { + return NewAuthIdentityClient(_m.config).QueryUser(_m) +} + +// QueryChannels queries the "channels" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery { + return NewAuthIdentityClient(_m.config).QueryChannels(_m) +} + +// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m) +} + +// Update returns a builder for updating this AuthIdentity. +// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne { + return NewAuthIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentity) Unwrap() *AuthIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentity) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.VerifiedAt; v != nil { + builder.WriteString("verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Issuer; v != nil { + builder.WriteString("issuer=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentities is a parsable slice of AuthIdentity. +type AuthIdentities []*AuthIdentity diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go new file mode 100644 index 0000000000000000000000000000000000000000..c90be759e827db02e0f8637d6a2bdfe559e10303 --- /dev/null +++ b/backend/ent/authidentity/authidentity.go @@ -0,0 +1,209 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentity type in the database. + Label = "auth_identity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldVerifiedAt holds the string denoting the verified_at field in the database. + FieldVerifiedAt = "verified_at" + // FieldIssuer holds the string denoting the issuer field in the database. + FieldIssuer = "issuer" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeChannels holds the string denoting the channels edge name in mutations. + EdgeChannels = "channels" + // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations. + EdgeAdoptionDecisions = "adoption_decisions" + // Table holds the table name of the authidentity in the database. + Table = "auth_identities" + // UserTable is the table that holds the user relation/edge. + UserTable = "auth_identities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // ChannelsTable is the table that holds the channels relation/edge. + ChannelsTable = "auth_identity_channels" + // ChannelsInverseTable is the table name for the AuthIdentityChannel entity. + // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package. + ChannelsInverseTable = "auth_identity_channels" + // ChannelsColumn is the table column denoting the channels relation/edge. + ChannelsColumn = "identity_id" + // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge. + AdoptionDecisionsTable = "identity_adoption_decisions" + // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionsInverseTable = "identity_adoption_decisions" + // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge. + AdoptionDecisionsColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldUserID, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldVerifiedAt, + FieldIssuer, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByVerifiedAt orders the results by the verified_at field. +func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc() +} + +// ByIssuer orders the results by the issuer field. +func ByIssuer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIssuer, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChannelsCount orders the results by channels count. +func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...) + } +} + +// ByChannels orders the results by channels terms. +func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAdoptionDecisionsCount orders the results by adoption_decisions count. +func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...) + } +} + +// ByAdoptionDecisions orders the results by adoption_decisions terms. +func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newChannelsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChannelsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) +} +func newAdoptionDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) +} diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go new file mode 100644 index 0000000000000000000000000000000000000000..3dbf317879b8813b2dee21acdda4f6122d477865 --- /dev/null +++ b/backend/ent/authidentity/where.go @@ -0,0 +1,600 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ. +func VerifiedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ. +func Issuer(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// VerifiedAtEQ applies the EQ predicate on the "verified_at" field. +func VerifiedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field. +func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtIn applies the In predicate on the "verified_at" field. +func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field. +func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtGT applies the GT predicate on the "verified_at" field. +func VerifiedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v)) +} + +// VerifiedAtGTE applies the GTE predicate on the "verified_at" field. +func VerifiedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v)) +} + +// VerifiedAtLT applies the LT predicate on the "verified_at" field. +func VerifiedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v)) +} + +// VerifiedAtLTE applies the LTE predicate on the "verified_at" field. +func VerifiedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v)) +} + +// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field. +func VerifiedAtIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt)) +} + +// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field. +func VerifiedAtNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt)) +} + +// IssuerEQ applies the EQ predicate on the "issuer" field. +func IssuerEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// IssuerNEQ applies the NEQ predicate on the "issuer" field. +func IssuerNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v)) +} + +// IssuerIn applies the In predicate on the "issuer" field. +func IssuerIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...)) +} + +// IssuerNotIn applies the NotIn predicate on the "issuer" field. +func IssuerNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...)) +} + +// IssuerGT applies the GT predicate on the "issuer" field. +func IssuerGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v)) +} + +// IssuerGTE applies the GTE predicate on the "issuer" field. +func IssuerGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v)) +} + +// IssuerLT applies the LT predicate on the "issuer" field. +func IssuerLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v)) +} + +// IssuerLTE applies the LTE predicate on the "issuer" field. +func IssuerLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v)) +} + +// IssuerContains applies the Contains predicate on the "issuer" field. +func IssuerContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v)) +} + +// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field. +func IssuerHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v)) +} + +// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field. +func IssuerHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v)) +} + +// IssuerIsNil applies the IsNil predicate on the "issuer" field. +func IssuerIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer)) +} + +// IssuerNotNil applies the NotNil predicate on the "issuer" field. +func IssuerNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer)) +} + +// IssuerEqualFold applies the EqualFold predicate on the "issuer" field. +func IssuerEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v)) +} + +// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field. +func IssuerContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChannels applies the HasEdge predicate on the "channels" edge. +func HasChannels() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates). +func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newChannelsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge. +func HasAdoptionDecisions() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates). +func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newAdoptionDecisionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go new file mode 100644 index 0000000000000000000000000000000000000000..e287705ce2af71def4c2140c10877347c24ff459 --- /dev/null +++ b/backend/ent/authidentity_create.go @@ -0,0 +1,1036 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityCreate is the builder for creating a AuthIdentity entity. +type AuthIdentityCreate struct { + config + mutation *AuthIdentityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetVerifiedAt sets the "verified_at" field. +func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetVerifiedAt(v) + return _c +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetVerifiedAt(*v) + } + return _c +} + +// SetIssuer sets the "issuer" field. +func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate { + _c.mutation.SetIssuer(v) + return _c +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate { + if v != nil { + _c.SetIssuer(*v) + } + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate { + return _c.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddChannelIDs(ids...) + return _c +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddAdoptionDecisionIDs(ids...) + return _c +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation { + return _c.mutation +} + +// Save creates the AuthIdentity in the database. +func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentity.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentity.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentity.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)} + } + return nil +} + +func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + _node.VerifiedAt = &value + } + if value, ok := _c.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + _node.Issuer = &value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne { + _c.conflict = opts + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityUpsertOne is the builder for "upsert"-ing + // one AuthIdentity node. + AuthIdentityUpsertOne struct { + create *AuthIdentityCreate + } + + // AuthIdentityUpsert is the "OnConflict" setter. + AuthIdentityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUpdatedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert { + u.Set(authidentity.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUserID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderSubject) + return u +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldVerifiedAt, v) + return u +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldVerifiedAt) + return u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldVerifiedAt) + return u +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldIssuer, v) + return u +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldIssuer) + return u +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldIssuer) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert { + u.Set(authidentity.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk. +type AuthIdentityCreateBulk struct { + config + err error + builders []*AuthIdentityCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentity entities in the database. +func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk { + _c.conflict = opts + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// AuthIdentityUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentity nodes. +type AuthIdentityUpsertBulk struct { + create *AuthIdentityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..4f1f6f3ce48d3a28fb14bd08fdcc11e4b4420576 --- /dev/null +++ b/backend/ent/authidentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityDelete is the builder for deleting a AuthIdentity entity. +type AuthIdentityDelete struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity. +type AuthIdentityDeleteOne struct { + _d *AuthIdentityDelete +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go new file mode 100644 index 0000000000000000000000000000000000000000..ff27ef3cd260d445c356f55e39e761055fd25ac0 --- /dev/null +++ b/backend/ent/authidentity_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityQuery is the builder for querying AuthIdentity entities. +type AuthIdentityQuery struct { + config + ctx *QueryContext + order []authidentity.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentity + withUser *UserQuery + withChannels *AuthIdentityChannelQuery + withAdoptionDecisions *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityQuery builder. +func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *AuthIdentityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChannels chains the current query on the "channels" edge. +func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge. +func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentity entity from the query. +// Returns a *NotFoundError when no AuthIdentity was found. +func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentity ID from the query. +// Returns a *NotFoundError when no AuthIdentity ID was found. +func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentity entity is found. +// Returns a *NotFoundError when no AuthIdentity entities are found. +func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentity.Label} + default: + return nil, &NotSingularError{authidentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentity ID in the query. +// Returns a *NotSingularError when more than one AuthIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentity.Label} + default: + err = &NotSingularError{authidentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentities. +func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]() + return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentity IDs. +func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery { + if _q == nil { + return nil + } + return &AuthIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentity{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withChannels: _q.withChannels.Clone(), + withAdoptionDecisions: _q.withAdoptionDecisions.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithChannels tells the query-builder to eager-load the nodes that are connected to +// the "channels" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withChannels = query + return _q +} + +// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecisions = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// GroupBy(authidentity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// Select(authidentity.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q} + sbuild.label = authidentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentitySelect configured with the given aggregations. +func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) { + var ( + nodes = []*AuthIdentity{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withUser != nil, + _q.withChannels != nil, + _q.withAdoptionDecisions != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentity{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withChannels; query != nil { + if err := _q.loadChannels(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} }, + func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecisions; query != nil { + if err := _q.loadAdoptionDecisions(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} }, + func(n *AuthIdentity, e *IdentityAdoptionDecision) { + n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentity) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID) + } + query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + if fk == nil { + return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for i := range fields { + if fields[i] != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(authidentity.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities. +type AuthIdentityGroupBy struct { + selector + build *AuthIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities. +type AuthIdentitySelect struct { + *AuthIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go new file mode 100644 index 0000000000000000000000000000000000000000..c457470b9b17b7bd03231239b19455e5e53d893a --- /dev/null +++ b/backend/ent/authidentity_update.go @@ -0,0 +1,923 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityUpdate is the builder for updating AuthIdentity entities. +type AuthIdentityUpdate struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity. +type AuthIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentity entity. +func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for _, f := range fields { + if !authidentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff3e5d1c88a1d5e596e7ee998c6d14eb1f34407 --- /dev/null +++ b/backend/ent/authidentitychannel.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema. +type AuthIdentityChannel struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID int64 `json:"identity_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // Channel holds the value of the "channel" field. + Channel string `json:"channel,omitempty"` + // ChannelAppID holds the value of the "channel_app_id" field. + ChannelAppID string `json:"channel_app_id,omitempty"` + // ChannelSubject holds the value of the "channel_subject" field. + ChannelSubject string `json:"channel_subject,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set. + Edges AuthIdentityChannelEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityChannelEdges struct { + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldMetadata: + values[i] = new([]byte) + case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID: + values[i] = new(sql.NullInt64) + case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject: + values[i] = new(sql.NullString) + case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentityChannel fields. +func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentitychannel.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentitychannel.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentitychannel.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = value.Int64 + } + case authidentitychannel.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentitychannel.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentitychannel.FieldChannel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel", values[i]) + } else if value.Valid { + _m.Channel = value.String + } + case authidentitychannel.FieldChannelAppID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_app_id", values[i]) + } else if value.Valid { + _m.ChannelAppID = value.String + } + case authidentitychannel.FieldChannelSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_subject", values[i]) + } else if value.Valid { + _m.ChannelSubject = value.String + } + case authidentitychannel.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity. +func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery { + return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this AuthIdentityChannel. +// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne { + return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentityChannel is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentityChannel) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentityChannel(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", _m.IdentityID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("channel=") + builder.WriteString(_m.Channel) + builder.WriteString(", ") + builder.WriteString("channel_app_id=") + builder.WriteString(_m.ChannelAppID) + builder.WriteString(", ") + builder.WriteString("channel_subject=") + builder.WriteString(_m.ChannelSubject) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentityChannels is a parsable slice of AuthIdentityChannel. +type AuthIdentityChannels []*AuthIdentityChannel diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go new file mode 100644 index 0000000000000000000000000000000000000000..7dcc98bb60b185f24fa7f97de257c4cb480ca811 --- /dev/null +++ b/backend/ent/authidentitychannel/authidentitychannel.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentitychannel type in the database. + Label = "auth_identity_channel" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldChannel holds the string denoting the channel field in the database. + FieldChannel = "channel" + // FieldChannelAppID holds the string denoting the channel_app_id field in the database. + FieldChannelAppID = "channel_app_id" + // FieldChannelSubject holds the string denoting the channel_subject field in the database. + FieldChannelSubject = "channel_subject" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the authidentitychannel in the database. + Table = "auth_identity_channels" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "auth_identity_channels" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentitychannel fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldIdentityID, + FieldProviderType, + FieldProviderKey, + FieldChannel, + FieldChannelAppID, + FieldChannelSubject, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + ChannelValidator func(string) error + // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + ChannelAppIDValidator func(string) error + // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + ChannelSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentityChannel queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByChannel orders the results by the channel field. +func ByChannel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannel, opts...).ToFunc() +} + +// ByChannelAppID orders the results by the channel_app_id field. +func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelAppID, opts...).ToFunc() +} + +// ByChannelSubject orders the results by the channel_subject field. +func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelSubject, opts...).ToFunc() +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go new file mode 100644 index 0000000000000000000000000000000000000000..827dc38450ede777c9442f4855c6c234ad03f1c2 --- /dev/null +++ b/backend/ent/authidentitychannel/where.go @@ -0,0 +1,559 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ. +func Channel(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ. +func ChannelAppID(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ. +func ChannelSubject(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ChannelEQ applies the EQ predicate on the "channel" field. +func ChannelEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelNEQ applies the NEQ predicate on the "channel" field. +func ChannelNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v)) +} + +// ChannelIn applies the In predicate on the "channel" field. +func ChannelIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...)) +} + +// ChannelNotIn applies the NotIn predicate on the "channel" field. +func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...)) +} + +// ChannelGT applies the GT predicate on the "channel" field. +func ChannelGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v)) +} + +// ChannelGTE applies the GTE predicate on the "channel" field. +func ChannelGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v)) +} + +// ChannelLT applies the LT predicate on the "channel" field. +func ChannelLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v)) +} + +// ChannelLTE applies the LTE predicate on the "channel" field. +func ChannelLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v)) +} + +// ChannelContains applies the Contains predicate on the "channel" field. +func ChannelContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v)) +} + +// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field. +func ChannelHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v)) +} + +// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field. +func ChannelHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v)) +} + +// ChannelEqualFold applies the EqualFold predicate on the "channel" field. +func ChannelEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v)) +} + +// ChannelContainsFold applies the ContainsFold predicate on the "channel" field. +func ChannelContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v)) +} + +// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field. +func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field. +func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDIn applies the In predicate on the "channel_app_id" field. +func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field. +func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field. +func ChannelAppIDGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v)) +} + +// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field. +func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v)) +} + +// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field. +func ChannelAppIDLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v)) +} + +// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field. +func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v)) +} + +// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field. +func ChannelAppIDContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v)) +} + +// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field. +func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v)) +} + +// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field. +func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v)) +} + +// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field. +func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v)) +} + +// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field. +func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v)) +} + +// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field. +func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field. +func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectIn applies the In predicate on the "channel_subject" field. +func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field. +func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectGT applies the GT predicate on the "channel_subject" field. +func ChannelSubjectGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v)) +} + +// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field. +func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v)) +} + +// ChannelSubjectLT applies the LT predicate on the "channel_subject" field. +func ChannelSubjectLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v)) +} + +// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field. +func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v)) +} + +// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field. +func ChannelSubjectContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v)) +} + +// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field. +func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v)) +} + +// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field. +func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v)) +} + +// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field. +func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v)) +} + +// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field. +func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v)) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go new file mode 100644 index 0000000000000000000000000000000000000000..4ce284792b16cdad35c6817a6b256cd0ec366be5 --- /dev/null +++ b/backend/ent/authidentitychannel_create.go @@ -0,0 +1,932 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity. +type AuthIdentityChannelCreate struct { + config + mutation *AuthIdentityChannelMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetChannel sets the "channel" field. +func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannel(v) + return _c +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelAppID(v) + return _c +} + +// SetChannelSubject sets the "channel_subject" field. +func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelSubject(v) + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation { + return _c.mutation +} + +// Save creates the AuthIdentityChannel in the database. +func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityChannelCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentitychannel.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentitychannel.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentitychannel.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityChannelCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)} + } + if _, ok := _c.mutation.IdentityID(); !ok { + return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.Channel(); !ok { + return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)} + } + if v, ok := _c.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelAppID(); !ok { + return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)} + } + if v, ok := _c.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelSubject(); !ok { + return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)} + } + if v, ok := _c.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)} + } + if len(_c.mutation.IdentityIDs()) == 0 { + return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)} + } + return nil +} + +func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentityChannel{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + _node.Channel = value + } + if value, ok := _c.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + _node.ChannelAppID = value + } + if value, ok := _c.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + _node.ChannelSubject = value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne { + _c.conflict = opts + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing + // one AuthIdentityChannel node. + AuthIdentityChannelUpsertOne struct { + create *AuthIdentityChannelCreate + } + + // AuthIdentityChannelUpsert is the "OnConflict" setter. + AuthIdentityChannelUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldUpdatedAt) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldIdentityID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderKey) + return u +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannel, v) + return u +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannel) + return u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelAppID, v) + return u +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelAppID) + return u +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelSubject, v) + return u +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelSubject) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk. +type AuthIdentityChannelCreateBulk struct { + config + err error + builders []*AuthIdentityChannelCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentityChannel entities in the database. +func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentityChannel, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityChannelMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk { + _c.conflict = opts + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentityChannel nodes. +type AuthIdentityChannelUpsertBulk struct { + create *AuthIdentityChannelCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..1a4acac59063fad5d80c787dec54c431928672e8 --- /dev/null +++ b/backend/ent/authidentitychannel_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity. +type AuthIdentityChannelDelete struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity. +type AuthIdentityChannelDeleteOne struct { + _d *AuthIdentityChannelDelete +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentitychannel.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go new file mode 100644 index 0000000000000000000000000000000000000000..7a202b7f1fd9573923b6f29d7bc18e0373efacd4 --- /dev/null +++ b/backend/ent/authidentitychannel_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities. +type AuthIdentityChannelQuery struct { + config + ctx *QueryContext + order []authidentitychannel.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentityChannel + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityChannelQuery builder. +func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentityChannel entity from the query. +// Returns a *NotFoundError when no AuthIdentityChannel was found. +func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentitychannel.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentityChannel ID from the query. +// Returns a *NotFoundError when no AuthIdentityChannel ID was found. +func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentitychannel.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found. +// Returns a *NotFoundError when no AuthIdentityChannel entities are found. +func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentitychannel.Label} + default: + return nil, &NotSingularError{authidentitychannel.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query. +// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentitychannel.Label} + default: + err = &NotSingularError{authidentitychannel.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentityChannels. +func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]() + return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentityChannel IDs. +func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery { + if _q == nil { + return nil + } + return &AuthIdentityChannelQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentitychannel.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// GroupBy(authidentitychannel.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityChannelGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentitychannel.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// Select(authidentitychannel.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q} + sbuild.label = authidentitychannel.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations. +func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentitychannel.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) { + var ( + nodes = []*AuthIdentityChannel{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentityChannel).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentityChannel{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentityChannel) + for i := range nodes { + fk := nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for i := range fields { + if fields[i] != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentitychannel.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentitychannel.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities. +type AuthIdentityChannelGroupBy struct { + selector + build *AuthIdentityChannelQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities. +type AuthIdentityChannelSelect struct { + *AuthIdentityChannelQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go new file mode 100644 index 0000000000000000000000000000000000000000..b550c4545fdf8187dff66fd9dc574920270dde9a --- /dev/null +++ b/backend/ent/authidentitychannel_update.go @@ -0,0 +1,581 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities. +type AuthIdentityChannelUpdate struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity. +type AuthIdentityChannelUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentityChannel entity. +func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for _, f := range fields { + if !authidentitychannel.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentityChannel{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015ad8b1fe576a57f6298562a5619ab3e4a8..b02f519b921d9d283681345a5b80e5688903edfd 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,12 +20,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -60,18 +64,26 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -118,12 +130,16 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.AuthIdentity = NewAuthIdentityClient(c.config) + c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) + c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config) c.PaymentAuditLog = NewPaymentAuditLogClient(c.config) c.PaymentOrder = NewPaymentOrderClient(c.config) c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config) + c.PendingAuthSession = NewPendingAuthSessionClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -332,11 +356,12 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *AuthIdentityMutation: + return c.AuthIdentity.mutate(ctx, m) + case *AuthIdentityChannelMutation: + return c.AuthIdentityChannel.mutate(ctx, m) case *ErrorPassthroughRuleMutation: return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *IdempotencyRecordMutation: return c.IdempotencyRecord.mutate(ctx, m) + case *IdentityAdoptionDecisionMutation: + return c.IdentityAdoptionDecision.mutate(ctx, m) case *PaymentAuditLogMutation: return c.PaymentAuditLog.mutate(ctx, m) case *PaymentOrderMutation: return c.PaymentOrder.mutate(ctx, m) case *PaymentProviderInstanceMutation: return c.PaymentProviderInstance.mutate(ctx, m) + case *PendingAuthSessionMutation: + return c.PendingAuthSession.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// AuthIdentityClient is a client for the AuthIdentity schema. +type AuthIdentityClient struct { + config +} + +// NewAuthIdentityClient returns a client for the AuthIdentity from the given config. +func NewAuthIdentityClient(c config) *AuthIdentityClient { + return &AuthIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`. +func (c *AuthIdentityClient) Use(hooks ...Hook) { + c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`. +func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...) +} + +// Create returns a builder for creating a AuthIdentity entity. +func (c *AuthIdentityClient) Create() *AuthIdentityCreate { + mutation := newAuthIdentityMutation(c.config, OpCreate) + return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentity entities. +func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk { + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentity. +func (c *AuthIdentityClient) Update() *AuthIdentityUpdate { + mutation := newAuthIdentityMutation(c.config, OpUpdate) + return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentity. +func (c *AuthIdentityClient) Delete() *AuthIdentityDelete { + mutation := newAuthIdentityMutation(c.config, OpDelete) + return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne { + builder := c.Delete().Where(authidentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentity. +func (c *AuthIdentityClient) Query() *AuthIdentityQuery { + return &AuthIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentity entity by its id. +func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) { + return c.Query().Where(authidentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChannels queries the channels edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityClient) Hooks() []Hook { + return c.hooks.AuthIdentity +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityClient) Interceptors() []Interceptor { + return c.inters.AuthIdentity +} + +func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op()) + } +} + +// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema. +type AuthIdentityChannelClient struct { + config +} + +// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config. +func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient { + return &AuthIdentityChannelClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`. +func (c *AuthIdentityChannelClient) Use(hooks ...Hook) { + c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`. +func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...) +} + +// Create returns a builder for creating a AuthIdentityChannel entity. +func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate { + mutation := newAuthIdentityChannelMutation(c.config, OpCreate) + return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities. +func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk { + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityChannelCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdate) + return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete { + mutation := newAuthIdentityChannelMutation(c.config, OpDelete) + return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne { + builder := c.Delete().Where(authidentitychannel.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityChannelDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery { + return &AuthIdentityChannelQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentityChannel}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentityChannel entity by its id. +func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) { + return c.Query().Where(authidentitychannel.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryIdentity queries the identity edge of a AuthIdentityChannel. +func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityChannelClient) Hooks() []Hook { + return c.hooks.AuthIdentityChannel +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityChannelClient) Interceptors() []Interceptor { + return c.inters.AuthIdentityChannel +} + +func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op()) + } +} + // ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. type ErrorPassthroughRuleClient struct { config @@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco } } +// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecisionClient struct { + config +} + +// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config. +func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient { + return &IdentityAdoptionDecisionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) { + c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...) +} + +// Create returns a builder for creating a IdentityAdoptionDecision entity. +func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate) + return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities. +func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk { + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdentityAdoptionDecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate) + return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete) + return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne { + builder := c.Delete().Where(identityadoptiondecision.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdentityAdoptionDecisionDeleteOne{builder} +} + +// Query returns a query builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery { + return &IdentityAdoptionDecisionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdentityAdoptionDecision}, + inters: c.Interceptors(), + } +} + +// Get returns a IdentityAdoptionDecision entity by its id. +func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) { + return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryIdentity queries the identity edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *IdentityAdoptionDecisionClient) Hooks() []Hook { + return c.hooks.IdentityAdoptionDecision +} + +// Interceptors returns the client interceptors. +func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor { + return c.inters.IdentityAdoptionDecision +} + +func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op()) + } +} + // PaymentAuditLogClient is a client for the PaymentAuditLog schema. type PaymentAuditLogClient struct { config @@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr } } +// PendingAuthSessionClient is a client for the PendingAuthSession schema. +type PendingAuthSessionClient struct { + config +} + +// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config. +func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient { + return &PendingAuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`. +func (c *PendingAuthSessionClient) Use(hooks ...Hook) { + c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`. +func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...) +} + +// Create returns a builder for creating a PendingAuthSession entity. +func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate { + mutation := newPendingAuthSessionMutation(c.config, OpCreate) + return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities. +func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk { + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PendingAuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate { + mutation := newPendingAuthSessionMutation(c.config, OpUpdate) + return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete { + mutation := newPendingAuthSessionMutation(c.config, OpDelete) + return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne { + builder := c.Delete().Where(pendingauthsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PendingAuthSessionDeleteOne{builder} +} + +// Query returns a query builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery { + return &PendingAuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePendingAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a PendingAuthSession entity by its id. +func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) { + return c.Query().Where(pendingauthsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryTargetUser queries the target_user edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *PendingAuthSessionClient) Hooks() []Hook { + return c.hooks.PendingAuthSession +} + +// Interceptors returns the client interceptors. +func (c *PendingAuthSessionClient) Interceptors() []Interceptor { + return c.inters.PendingAuthSession +} + +func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery { return query } +// QueryAuthIdentities queries the auth_identities edge of a User. +func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User. +func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 96ed5e03a9569a3b5a33922573aa9a2fbf090825..339e5369d7530391ab07279a9bff7016bbb59e3f 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -98,32 +102,36 @@ var ( func checkColumn(t, c string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - apikey.Table: apikey.ValidColumn, - account.Table: account.ValidColumn, - accountgroup.Table: accountgroup.ValidColumn, - announcement.Table: announcement.ValidColumn, - announcementread.Table: announcementread.ValidColumn, - errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, - group.Table: group.ValidColumn, - idempotencyrecord.Table: idempotencyrecord.ValidColumn, - paymentauditlog.Table: paymentauditlog.ValidColumn, - paymentorder.Table: paymentorder.ValidColumn, - paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, - promocode.Table: promocode.ValidColumn, - promocodeusage.Table: promocodeusage.ValidColumn, - proxy.Table: proxy.ValidColumn, - redeemcode.Table: redeemcode.ValidColumn, - securitysecret.Table: securitysecret.ValidColumn, - setting.Table: setting.ValidColumn, - subscriptionplan.Table: subscriptionplan.ValidColumn, - tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, - usagecleanuptask.Table: usagecleanuptask.ValidColumn, - usagelog.Table: usagelog.ValidColumn, - user.Table: user.ValidColumn, - userallowedgroup.Table: userallowedgroup.ValidColumn, - userattributedefinition.Table: userattributedefinition.ValidColumn, - userattributevalue.Table: userattributevalue.ValidColumn, - usersubscription.Table: usersubscription.ValidColumn, + apikey.Table: apikey.ValidColumn, + account.Table: account.ValidColumn, + accountgroup.Table: accountgroup.ValidColumn, + announcement.Table: announcement.ValidColumn, + announcementread.Table: announcementread.ValidColumn, + authidentity.Table: authidentity.ValidColumn, + authidentitychannel.Table: authidentitychannel.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, + group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, + identityadoptiondecision.Table: identityadoptiondecision.ValidColumn, + paymentauditlog.Table: paymentauditlog.ValidColumn, + paymentorder.Table: paymentorder.ValidColumn, + paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, + pendingauthsession.Table: pendingauthsession.ValidColumn, + promocode.Table: promocode.ValidColumn, + promocodeusage.Table: promocodeusage.ValidColumn, + proxy.Table: proxy.ValidColumn, + redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, + setting.Table: setting.ValidColumn, + subscriptionplan.Table: subscriptionplan.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, + usagecleanuptask.Table: usagecleanuptask.ValidColumn, + usagelog.Table: usagelog.ValidColumn, + user.Table: user.ValidColumn, + userallowedgroup.Table: userallowedgroup.ValidColumn, + userattributedefinition.Table: userattributedefinition.ValidColumn, + userattributevalue.Table: userattributevalue.ValidColumn, + usersubscription.Table: usersubscription.ValidColumn, }) }) return columnCheck(t, c) diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 199dacea0ecc8c8a06ac0e0789e5e968ec63b66e..46ac02bc202736b6ead35c567b89a9e8ffc4afdd 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary +// function as AuthIdentity mutator. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary +// function as AuthIdentityChannel mutator. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary // function as ErrorPassthroughRule mutator. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) @@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent. return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary +// function as IdentityAdoptionDecision mutator. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary // function as PaymentAuditLog mutator. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error) @@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary +// function as PendingAuthSession mutator. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PendingAuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go new file mode 100644 index 0000000000000000000000000000000000000000..ecaee65c2fc3b5f5cf3a6df4c3d5455a54691188 --- /dev/null +++ b/backend/ent/identityadoptiondecision.go @@ -0,0 +1,223 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecision struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // PendingAuthSessionID holds the value of the "pending_auth_session_id" field. + PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID *int64 `json:"identity_id,omitempty"` + // AdoptDisplayName holds the value of the "adopt_display_name" field. + AdoptDisplayName bool `json:"adopt_display_name,omitempty"` + // AdoptAvatar holds the value of the "adopt_avatar" field. + AdoptAvatar bool `json:"adopt_avatar,omitempty"` + // DecidedAt holds the value of the "decided_at" field. + DecidedAt time.Time `json:"decided_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set. + Edges IdentityAdoptionDecisionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph. +type IdentityAdoptionDecisionEdges struct { + // PendingAuthSession holds the value of the pending_auth_session edge. + PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"` + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) { + if e.PendingAuthSession != nil { + return e.PendingAuthSession, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: pendingauthsession.Label} + } + return nil, &NotLoadedError{edge: "pending_auth_session"} +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar: + values[i] = new(sql.NullBool) + case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID: + values[i] = new(sql.NullInt64) + case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdentityAdoptionDecision fields. +func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case identityadoptiondecision.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case identityadoptiondecision.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case identityadoptiondecision.FieldPendingAuthSessionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i]) + } else if value.Valid { + _m.PendingAuthSessionID = value.Int64 + } + case identityadoptiondecision.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = new(int64) + *_m.IdentityID = value.Int64 + } + case identityadoptiondecision.FieldAdoptDisplayName: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i]) + } else if value.Valid { + _m.AdoptDisplayName = value.Bool + } + case identityadoptiondecision.FieldAdoptAvatar: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i]) + } else if value.Valid { + _m.AdoptAvatar = value.Bool + } + case identityadoptiondecision.FieldDecidedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field decided_at", values[i]) + } else if value.Valid { + _m.DecidedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision. +// This includes values selected through modifiers, order, etc. +func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m) +} + +// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this IdentityAdoptionDecision. +// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne { + return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdentityAdoptionDecision is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdentityAdoptionDecision) String() string { + var builder strings.Builder + builder.WriteString("IdentityAdoptionDecision(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pending_auth_session_id=") + builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID)) + builder.WriteString(", ") + if v := _m.IdentityID; v != nil { + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("adopt_display_name=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName)) + builder.WriteString(", ") + builder.WriteString("adopt_avatar=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar)) + builder.WriteString(", ") + builder.WriteString("decided_at=") + builder.WriteString(_m.DecidedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision. +type IdentityAdoptionDecisions []*IdentityAdoptionDecision diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go new file mode 100644 index 0000000000000000000000000000000000000000..93adaf7397c4d18c07fee9601737b17f571905ef --- /dev/null +++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go @@ -0,0 +1,159 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the identityadoptiondecision type in the database. + Label = "identity_adoption_decision" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database. + FieldPendingAuthSessionID = "pending_auth_session_id" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database. + FieldAdoptDisplayName = "adopt_display_name" + // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database. + FieldAdoptAvatar = "adopt_avatar" + // FieldDecidedAt holds the string denoting the decided_at field in the database. + FieldDecidedAt = "decided_at" + // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations. + EdgePendingAuthSession = "pending_auth_session" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the identityadoptiondecision in the database. + Table = "identity_adoption_decisions" + // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge. + PendingAuthSessionTable = "identity_adoption_decisions" + // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionInverseTable = "pending_auth_sessions" + // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge. + PendingAuthSessionColumn = "pending_auth_session_id" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "identity_adoption_decisions" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for identityadoptiondecision fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldPendingAuthSessionID, + FieldIdentityID, + FieldAdoptDisplayName, + FieldAdoptAvatar, + FieldDecidedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field. + DefaultAdoptDisplayName bool + // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field. + DefaultAdoptAvatar bool + // DefaultDecidedAt holds the default value on creation for the "decided_at" field. + DefaultDecidedAt func() time.Time +) + +// OrderOption defines the ordering options for the IdentityAdoptionDecision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionID orders the results by the pending_auth_session_id field. +func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByAdoptDisplayName orders the results by the adopt_display_name field. +func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc() +} + +// ByAdoptAvatar orders the results by the adopt_avatar field. +func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc() +} + +// ByDecidedAt orders the results by the decided_at field. +func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDecidedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionField orders the results by pending_auth_session field. +func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...)) + } +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newPendingAuthSessionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go new file mode 100644 index 0000000000000000000000000000000000000000..1968f175063ac5962573aa4fbfcb2f0742527835 --- /dev/null +++ b/backend/ent/identityadoptiondecision/where.go @@ -0,0 +1,342 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ. +func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ. +func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ. +func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ. +func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...)) +} + +// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field. +func IdentityIDIsNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID)) +} + +// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field. +func IdentityIDNotNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID)) +} + +// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field. +func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field. +func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v)) +} + +// DecidedAtEQ applies the EQ predicate on the "decided_at" field. +func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field. +func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v)) +} + +// DecidedAtIn applies the In predicate on the "decided_at" field. +func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...)) +} + +// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field. +func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...)) +} + +// DecidedAtGT applies the GT predicate on the "decided_at" field. +func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v)) +} + +// DecidedAtGTE applies the GTE predicate on the "decided_at" field. +func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v)) +} + +// DecidedAtLT applies the LT predicate on the "decided_at" field. +func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v)) +} + +// DecidedAtLTE applies the LTE predicate on the "decided_at" field. +func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v)) +} + +// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge. +func HasPendingAuthSession() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates). +func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newPendingAuthSessionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.NotPredicates(p)) +} diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go new file mode 100644 index 0000000000000000000000000000000000000000..491ba9f9a6626000bbc7095ba4390d03fcd00390 --- /dev/null +++ b/backend/ent/identityadoptiondecision_create.go @@ -0,0 +1,843 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionCreate struct { + config + mutation *IdentityAdoptionDecisionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetPendingAuthSessionID(v) + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetIdentityID(*v) + } + return _c +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptDisplayName(v) + return _c +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptDisplayName(*v) + } + return _c +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptAvatar(v) + return _c +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptAvatar(*v) + } + return _c +} + +// SetDecidedAt sets the "decided_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetDecidedAt(v) + return _c +} + +// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetDecidedAt(*v) + } + return _c +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate { + return _c.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation { + return _c.mutation +} + +// Save creates the IdentityAdoptionDecision in the database. +func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdentityAdoptionDecisionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := identityadoptiondecision.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + v := identityadoptiondecision.DefaultAdoptDisplayName + _c.mutation.SetAdoptDisplayName(v) + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + v := identityadoptiondecision.DefaultAdoptAvatar + _c.mutation.SetAdoptAvatar(v) + } + if _, ok := _c.mutation.DecidedAt(); !ok { + v := identityadoptiondecision.DefaultDecidedAt() + _c.mutation.SetDecidedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *IdentityAdoptionDecisionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)} + } + if _, ok := _c.mutation.PendingAuthSessionID(); !ok { + return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)} + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)} + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)} + } + if _, ok := _c.mutation.DecidedAt(); !ok { + return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)} + } + if len(_c.mutation.PendingAuthSessionIDs()) == 0 { + return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)} + } + return nil +} + +func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) { + var ( + _node = &IdentityAdoptionDecision{config: _c.config} + _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + _node.AdoptDisplayName = value + } + if value, ok := _c.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + _node.AdoptAvatar = value + } + if value, ok := _c.mutation.DecidedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value) + _node.DecidedAt = value + } + if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.PendingAuthSessionID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +type ( + // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing + // one IdentityAdoptionDecision node. + IdentityAdoptionDecisionUpsertOne struct { + create *IdentityAdoptionDecisionCreate + } + + // IdentityAdoptionDecisionUpsert is the "OnConflict" setter. + IdentityAdoptionDecisionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldUpdatedAt) + return u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v) + return u +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldIdentityID) + return u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetNull(identityadoptiondecision.FieldIdentityID) + return u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptDisplayName, v) + return u +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName) + return u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptAvatar, v) + return u +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := u.create.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk. +type IdentityAdoptionDecisionCreateBulk struct { + config + err error + builders []*IdentityAdoptionDecisionCreate + conflict []sql.ConflictOption +} + +// Save creates the IdentityAdoptionDecision entities in the database. +func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdentityAdoptionDecision, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdentityAdoptionDecisionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing +// a bulk of IdentityAdoptionDecision nodes. +type IdentityAdoptionDecisionUpsertBulk struct { + create *IdentityAdoptionDecisionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := b.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..ef3d328d050c8743b8d6080d8d7daca8a6b59d4a --- /dev/null +++ b/backend/ent/identityadoptiondecision_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDelete struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDeleteOne struct { + _d *IdentityAdoptionDecisionDelete +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{identityadoptiondecision.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go new file mode 100644 index 0000000000000000000000000000000000000000..4082d8ee74e4d0210372e7e710032975cde08ceb --- /dev/null +++ b/backend/ent/identityadoptiondecision_query.go @@ -0,0 +1,721 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionQuery struct { + config + ctx *QueryContext + order []identityadoptiondecision.OrderOption + inters []Interceptor + predicates []predicate.IdentityAdoptionDecision + withPendingAuthSession *PendingAuthSessionQuery + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder. +func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first IdentityAdoptionDecision entity from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision was found. +func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{identityadoptiondecision.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdentityAdoptionDecision ID from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found. +func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{identityadoptiondecision.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found. +// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found. +func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{identityadoptiondecision.Label} + default: + return nil, &NotSingularError{identityadoptiondecision.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{identityadoptiondecision.Label} + default: + err = &NotSingularError{identityadoptiondecision.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdentityAdoptionDecisions. +func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]() + return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdentityAdoptionDecision IDs. +func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery { + if _q == nil { + return nil + } + return &IdentityAdoptionDecisionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]identityadoptiondecision.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...), + withPendingAuthSession: _q.withPendingAuthSession.Clone(), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSession = query + return _q +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// GroupBy(identityadoptiondecision.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdentityAdoptionDecisionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = identityadoptiondecision.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// Select(identityadoptiondecision.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q} + sbuild.label = identityadoptiondecision.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations. +func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !identityadoptiondecision.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) { + var ( + nodes = []*IdentityAdoptionDecision{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withPendingAuthSession != nil, + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdentityAdoptionDecision).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdentityAdoptionDecision{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withPendingAuthSession; query != nil { + if err := _q.loadPendingAuthSession(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil { + return nil, err + } + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + fk := nodes[i].PendingAuthSessionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(pendingauthsession.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + if nodes[i].IdentityID == nil { + continue + } + fk := *nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for i := range fields { + if fields[i] != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withPendingAuthSession != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(identityadoptiondecision.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = identityadoptiondecision.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionGroupBy struct { + selector + build *IdentityAdoptionDecisionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionSelect struct { + *IdentityAdoptionDecisionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v) +} + +func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go new file mode 100644 index 0000000000000000000000000000000000000000..0ca21d270aa3dd30254c874b44e25bd7f34bd203 --- /dev/null +++ b/backend/ent/identityadoptiondecision_update.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionUpdate struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdate) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdentityAdoptionDecision entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for _, f := range fields { + if !identityadoptiondecision.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &IdentityAdoptionDecision{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bbba5f6cdec4ba26ce89e5a81f532d30f9..157c51225dac4eb5b6d404c788b4c65fa27b4d09 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,12 +13,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) @@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + +// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error) @@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + +// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser. +type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.AuthIdentityQuery: + return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil + case *ent.AuthIdentityChannelQuery: + return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil case *ent.ErrorPassthroughRuleQuery: return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.IdempotencyRecordQuery: return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil + case *ent.IdentityAdoptionDecisionQuery: + return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil case *ent.PaymentAuditLogQuery: return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil case *ent.PaymentOrderQuery: return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil case *ent.PaymentProviderInstanceQuery: return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil + case *ent.PendingAuthSessionQuery: + return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 68bdbf5546839faeaffba965d35f9ac56e616e03..81f6a664d02f49e205d38bfec0ce5b353feefacd 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -338,6 +338,89 @@ var ( }, }, } + // AuthIdentitiesColumns holds the columns for the "auth_identities" table. + AuthIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // AuthIdentitiesTable holds the schema information for the "auth_identities" table. + AuthIdentitiesTable = &schema.Table{ + Name: "auth_identities", + Columns: AuthIdentitiesColumns, + PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identities_users_auth_identities", + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentity_provider_type_provider_key_provider_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]}, + }, + { + Name: "authidentity_user_id", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + }, + { + Name: "authidentity_user_id_provider_type", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]}, + }, + }, + } + // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table. + AuthIdentityChannelsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel", Type: field.TypeString, Size: 20}, + {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "identity_id", Type: field.TypeInt64}, + } + // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table. + AuthIdentityChannelsTable = &schema.Table{ + Name: "auth_identity_channels", + Columns: AuthIdentityChannelsColumns, + PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identity_channels_auth_identities_channels", + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]}, + }, + { + Name: "authidentitychannel_identity_id", + Unique: false, + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + }, + }, + } // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. ErrorPassthroughRulesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -485,6 +568,49 @@ var ( }, }, } + // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "adopt_display_name", Type: field.TypeBool, Default: false}, + {Name: "adopt_avatar", Type: field.TypeBool, Default: false}, + {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "identity_id", Type: field.TypeInt64, Nullable: true}, + {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true}, + } + // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsTable = &schema.Table{ + Name: "identity_adoption_decisions", + Columns: IdentityAdoptionDecisionsColumns, + PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "identityadoptiondecision_pending_auth_session_id", + Unique: true, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + }, + { + Name: "identityadoptiondecision_identity_id", + Unique: false, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + }, + }, + } // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table. PaymentAuditLogsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -528,6 +654,8 @@ var ( {Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "subscription_days", Type: field.TypeInt, Nullable: true}, {Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64}, + {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30}, + {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"}, {Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}}, {Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, @@ -556,7 +684,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "payment_orders_users_payment_orders", - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[39]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -570,32 +698,32 @@ var ( { Name: "paymentorder_user_id", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[39]}, }, { Name: "paymentorder_status", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[19]}, + Columns: []*schema.Column{PaymentOrdersColumns[21]}, }, { Name: "paymentorder_expires_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[27]}, + Columns: []*schema.Column{PaymentOrdersColumns[29]}, }, { Name: "paymentorder_created_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[35]}, + Columns: []*schema.Column{PaymentOrdersColumns[37]}, }, { Name: "paymentorder_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[30]}, }, { Name: "paymentorder_payment_type_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]}, }, { Name: "paymentorder_order_type", @@ -638,6 +766,72 @@ var ( }, }, } + // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table. + PendingAuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "session_token", Type: field.TypeString, Size: 255}, + {Name: "intent", Type: field.TypeString, Size: 40}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "target_user_id", Type: field.TypeInt64, Nullable: true}, + } + // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table. + PendingAuthSessionsTable = &schema.Table{ + Name: "pending_auth_sessions", + Columns: PendingAuthSessionsColumns, + PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pending_auth_sessions_users_pending_auth_sessions", + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "pendingauthsession_session_token", + Unique: true, + Columns: []*schema.Column{PendingAuthSessionsColumns[3]}, + }, + { + Name: "pendingauthsession_target_user_id", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + }, + { + Name: "pendingauthsession_expires_at", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[19]}, + }, + { + Name: "pendingauthsession_provider_type_provider_key_provider_subject", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]}, + }, + { + Name: "pendingauthsession_completion_code_hash", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[14]}, + }, + }, + } // PromoCodesColumns holds the columns for the "promo_codes" table. PromoCodesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1079,6 +1273,9 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"}, + {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -1318,12 +1515,16 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + AuthIdentitiesTable, + AuthIdentityChannelsTable, ErrorPassthroughRulesTable, GroupsTable, IdempotencyRecordsTable, + IdentityAdoptionDecisionsTable, PaymentAuditLogsTable, PaymentOrdersTable, PaymentProviderInstancesTable, + PendingAuthSessionsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1365,6 +1566,14 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable + AuthIdentitiesTable.Annotation = &entsql.Annotation{ + Table: "auth_identities", + } + AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + AuthIdentityChannelsTable.Annotation = &entsql.Annotation{ + Table: "auth_identity_channels", + } ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ Table: "error_passthrough_rules", } @@ -1374,6 +1583,11 @@ func init() { IdempotencyRecordsTable.Annotation = &entsql.Annotation{ Table: "idempotency_records", } + IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable + IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{ + Table: "identity_adoption_decisions", + } PaymentAuditLogsTable.Annotation = &entsql.Annotation{ Table: "payment_audit_logs", } @@ -1384,6 +1598,10 @@ func init() { PaymentProviderInstancesTable.Annotation = &entsql.Annotation{ Table: "payment_provider_instances", } + PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable + PendingAuthSessionsTable.Annotation = &entsql.Annotation{ + Table: "pending_auth_sessions", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 524ccb925f4561a123bdcd7f8449eac02a72c64b..ec4a4070f0213a63ddba9deb52aa762a7ab3299a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -51,32 +55,36 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. - TypeAPIKey = "APIKey" - TypeAccount = "Account" - TypeAccountGroup = "AccountGroup" - TypeAnnouncement = "Announcement" - TypeAnnouncementRead = "AnnouncementRead" - TypeErrorPassthroughRule = "ErrorPassthroughRule" - TypeGroup = "Group" - TypeIdempotencyRecord = "IdempotencyRecord" - TypePaymentAuditLog = "PaymentAuditLog" - TypePaymentOrder = "PaymentOrder" - TypePaymentProviderInstance = "PaymentProviderInstance" - TypePromoCode = "PromoCode" - TypePromoCodeUsage = "PromoCodeUsage" - TypeProxy = "Proxy" - TypeRedeemCode = "RedeemCode" - TypeSecuritySecret = "SecuritySecret" - TypeSetting = "Setting" - TypeSubscriptionPlan = "SubscriptionPlan" - TypeTLSFingerprintProfile = "TLSFingerprintProfile" - TypeUsageCleanupTask = "UsageCleanupTask" - TypeUsageLog = "UsageLog" - TypeUser = "User" - TypeUserAllowedGroup = "UserAllowedGroup" - TypeUserAttributeDefinition = "UserAttributeDefinition" - TypeUserAttributeValue = "UserAttributeValue" - TypeUserSubscription = "UserSubscription" + TypeAPIKey = "APIKey" + TypeAccount = "Account" + TypeAccountGroup = "AccountGroup" + TypeAnnouncement = "Announcement" + TypeAnnouncementRead = "AnnouncementRead" + TypeAuthIdentity = "AuthIdentity" + TypeAuthIdentityChannel = "AuthIdentityChannel" + TypeErrorPassthroughRule = "ErrorPassthroughRule" + TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" + TypeIdentityAdoptionDecision = "IdentityAdoptionDecision" + TypePaymentAuditLog = "PaymentAuditLog" + TypePaymentOrder = "PaymentOrder" + TypePaymentProviderInstance = "PaymentProviderInstance" + TypePendingAuthSession = "PendingAuthSession" + TypePromoCode = "PromoCode" + TypePromoCodeUsage = "PromoCodeUsage" + TypeProxy = "Proxy" + TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" + TypeSetting = "Setting" + TypeSubscriptionPlan = "SubscriptionPlan" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" + TypeUsageCleanupTask = "UsageCleanupTask" + TypeUsageLog = "UsageLog" + TypeUser = "User" + TypeUserAllowedGroup = "UserAllowedGroup" + TypeUserAttributeDefinition = "UserAttributeDefinition" + TypeUserAttributeValue = "UserAttributeValue" + TypeUserSubscription = "UserSubscription" ) // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. @@ -6887,49 +6895,45 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } -// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. -type ErrorPassthroughRuleMutation struct { +// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph. +type AuthIdentityMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - name *string - enabled *bool - priority *int - addpriority *int - error_codes *[]int - appenderror_codes []int - keywords *[]string - appendkeywords []string - match_mode *string - platforms *[]string - appendplatforms []string - passthrough_code *bool - response_code *int - addresponse_code *int - passthrough_body *bool - custom_message *string - skip_monitoring *bool - description *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*ErrorPassthroughRule, error) - predicates []predicate.ErrorPassthroughRule + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + provider_subject *string + verified_at *time.Time + issuer *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + user *int64 + cleareduser bool + channels map[int64]struct{} + removedchannels map[int64]struct{} + clearedchannels bool + adoption_decisions map[int64]struct{} + removedadoption_decisions map[int64]struct{} + clearedadoption_decisions bool + done bool + oldValue func(context.Context) (*AuthIdentity, error) + predicates []predicate.AuthIdentity } -var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) +var _ ent.Mutation = (*AuthIdentityMutation)(nil) -// errorpassthroughruleOption allows management of the mutation configuration using functional options. -type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) +// authidentityOption allows management of the mutation configuration using functional options. +type authidentityOption func(*AuthIdentityMutation) -// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. -func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { - m := &ErrorPassthroughRuleMutation{ +// newAuthIdentityMutation creates new mutation for the AuthIdentity entity. +func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation { + m := &AuthIdentityMutation{ config: c, op: op, - typ: TypeErrorPassthroughRule, + typ: TypeAuthIdentity, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -6938,20 +6942,20 @@ func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughru return m } -// withErrorPassthroughRuleID sets the ID field of the mutation. -func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { +// withAuthIdentityID sets the ID field of the mutation. +func withAuthIdentityID(id int64) authidentityOption { + return func(m *AuthIdentityMutation) { var ( err error once sync.Once - value *ErrorPassthroughRule + value *AuthIdentity ) - m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + m.oldValue = func(ctx context.Context) (*AuthIdentity, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + value, err = m.Client().AuthIdentity.Get(ctx, id) } }) return value, err @@ -6960,10 +6964,10 @@ func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { } } -// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. -func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { - m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { +// withAuthIdentity sets the old AuthIdentity of the mutation. +func withAuthIdentity(node *AuthIdentity) authidentityOption { + return func(m *AuthIdentityMutation) { + m.oldValue = func(context.Context) (*AuthIdentity, error) { return node, nil } m.id = &node.ID @@ -6972,7 +6976,7 @@ func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOp // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m ErrorPassthroughRuleMutation) Client() *Client { +func (m AuthIdentityMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -6980,7 +6984,7 @@ func (m ErrorPassthroughRuleMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { +func (m AuthIdentityMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -6991,7 +6995,7 @@ func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { +func (m *AuthIdentityMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -7002,7 +7006,7 @@ func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -7011,19 +7015,19 @@ func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. -func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -7031,10 +7035,10 @@ func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -7049,17 +7053,17 @@ func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { +func (m *AuthIdentityMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -7067,10 +7071,10 @@ func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -7085,942 +7089,1510 @@ func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { +func (m *AuthIdentityMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetName sets the "name" field. -func (m *ErrorPassthroughRuleMutation) SetName(s string) { - m.name = &s +// SetUserID sets the "user_id" field. +func (m *AuthIdentityMutation) SetUserID(i int64) { + m.user = &i } -// Name returns the value of the "name" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { - v := m.name +// UserID returns the value of the "user_id" field in the mutation. +func (m *AuthIdentityMutation) UserID() (r int64, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUserID returns the old "user_id" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.Name, nil + return oldValue.UserID, nil } -// ResetName resets all changes to the "name" field. -func (m *ErrorPassthroughRuleMutation) ResetName() { - m.name = nil +// ResetUserID resets all changes to the "user_id" field. +func (m *AuthIdentityMutation) ResetUserID() { + m.user = nil } -// SetEnabled sets the "enabled" field. -func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { - m.enabled = &b +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityMutation) SetProviderType(s string) { + m.provider_type = &s } -// Enabled returns the value of the "enabled" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) { + v := m.provider_type if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldProviderType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) } - return oldValue.Enabled, nil + return oldValue.ProviderType, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *ErrorPassthroughRuleMutation) ResetEnabled() { - m.enabled = nil +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityMutation) ResetProviderType() { + m.provider_type = nil } -// SetPriority sets the "priority" field. -func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { - m.priority = &i - m.addpriority = nil +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityMutation) SetProviderKey(s string) { + m.provider_key = &s } -// Priority returns the value of the "priority" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { - v := m.priority +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key if v == nil { return } return *v, true } -// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { +func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPriority is only allowed on UpdateOne operations") + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPriority requires an ID field in the mutation") + return v, errors.New("OldProviderKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPriority: %w", err) + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) } - return oldValue.Priority, nil + return oldValue.ProviderKey, nil } -// AddPriority adds i to the "priority" field. -func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { - if m.addpriority != nil { - *m.addpriority += i - } else { - m.addpriority = &i - } +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityMutation) ResetProviderKey() { + m.provider_key = nil } -// AddedPriority returns the value that was added to the "priority" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { - v := m.addpriority +// SetProviderSubject sets the "provider_subject" field. +func (m *AuthIdentityMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject if v == nil { return } return *v, true } -// ResetPriority resets all changes to the "priority" field. -func (m *ErrorPassthroughRuleMutation) ResetPriority() { - m.priority = nil - m.addpriority = nil +// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil } -// SetErrorCodes sets the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { - m.error_codes = &i - m.appenderror_codes = nil +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *AuthIdentityMutation) ResetProviderSubject() { + m.provider_subject = nil } -// ErrorCodes returns the value of the "error_codes" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { - v := m.error_codes +// SetVerifiedAt sets the "verified_at" field. +func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) { + m.verified_at = &t +} + +// VerifiedAt returns the value of the "verified_at" field in the mutation. +func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) { + v := m.verified_at if v == nil { return } return *v, true } -// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { +func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorCodes requires an ID field in the mutation") + return v, errors.New("OldVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) - } - return oldValue.ErrorCodes, nil -} - -// AppendErrorCodes adds i to the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { - m.appenderror_codes = append(m.appenderror_codes, i...) -} - -// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { - if len(m.appenderror_codes) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err) } - return m.appenderror_codes, true + return oldValue.VerifiedAt, nil } -// ClearErrorCodes clears the value of the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} +// ClearVerifiedAt clears the value of the "verified_at" field. +func (m *AuthIdentityMutation) ClearVerifiedAt() { + m.verified_at = nil + m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{} } -// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] +// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation. +func (m *AuthIdentityMutation) VerifiedAtCleared() bool { + _, ok := m.clearedFields[authidentity.FieldVerifiedAt] return ok } -// ResetErrorCodes resets all changes to the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +// ResetVerifiedAt resets all changes to the "verified_at" field. +func (m *AuthIdentityMutation) ResetVerifiedAt() { + m.verified_at = nil + delete(m.clearedFields, authidentity.FieldVerifiedAt) } -// SetKeywords sets the "keywords" field. -func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { - m.keywords = &s - m.appendkeywords = nil +// SetIssuer sets the "issuer" field. +func (m *AuthIdentityMutation) SetIssuer(s string) { + m.issuer = &s } -// Keywords returns the value of the "keywords" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { - v := m.keywords +// Issuer returns the value of the "issuer" field in the mutation. +func (m *AuthIdentityMutation) Issuer() (r string, exists bool) { + v := m.issuer if v == nil { return } return *v, true } -// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { +func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + return v, errors.New("OldIssuer is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldKeywords requires an ID field in the mutation") + return v, errors.New("OldIssuer requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldKeywords: %w", err) - } - return oldValue.Keywords, nil -} - -// AppendKeywords adds s to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { - m.appendkeywords = append(m.appendkeywords, s...) -} - -// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { - if len(m.appendkeywords) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldIssuer: %w", err) } - return m.appendkeywords, true + return oldValue.Issuer, nil } -// ClearKeywords clears the value of the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ClearKeywords() { - m.keywords = nil - m.appendkeywords = nil - m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +// ClearIssuer clears the value of the "issuer" field. +func (m *AuthIdentityMutation) ClearIssuer() { + m.issuer = nil + m.clearedFields[authidentity.FieldIssuer] = struct{}{} } -// KeywordsCleared returns if the "keywords" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] +// IssuerCleared returns if the "issuer" field was cleared in this mutation. +func (m *AuthIdentityMutation) IssuerCleared() bool { + _, ok := m.clearedFields[authidentity.FieldIssuer] return ok } -// ResetKeywords resets all changes to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ResetKeywords() { - m.keywords = nil - m.appendkeywords = nil - delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +// ResetIssuer resets all changes to the "issuer" field. +func (m *AuthIdentityMutation) ResetIssuer() { + m.issuer = nil + delete(m.clearedFields, authidentity.FieldIssuer) } -// SetMatchMode sets the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { - m.match_mode = &s +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value } -// MatchMode returns the value of the "match_mode" field in the mutation. -func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { - v := m.match_mode +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata if v == nil { return } return *v, true } -// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMatchMode requires an ID field in the mutation") + return v, errors.New("OldMetadata requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) } - return oldValue.MatchMode, nil + return oldValue.Metadata, nil } -// ResetMatchMode resets all changes to the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { - m.match_mode = nil +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityMutation) ResetMetadata() { + m.metadata = nil } -// SetPlatforms sets the "platforms" field. -func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { - m.platforms = &s - m.appendplatforms = nil +// ClearUser clears the "user" edge to the User entity. +func (m *AuthIdentityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[authidentity.FieldUserID] = struct{}{} } -// Platforms returns the value of the "platforms" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { - v := m.platforms - if v == nil { - return - } - return *v, true +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *AuthIdentityMutation) UserCleared() bool { + return m.cleareduser } -// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatforms requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) } - return oldValue.Platforms, nil + return } -// AppendPlatforms adds s to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { - m.appendplatforms = append(m.appendplatforms, s...) +// ResetUser resets all changes to the "user" edge. +func (m *AuthIdentityMutation) ResetUser() { + m.user = nil + m.cleareduser = false } -// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { - if len(m.appendplatforms) == 0 { - return nil, false +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids. +func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) { + if m.channels == nil { + m.channels = make(map[int64]struct{}) + } + for i := range ids { + m.channels[ids[i]] = struct{}{} } - return m.appendplatforms, true } -// ClearPlatforms clears the value of the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { - m.platforms = nil - m.appendplatforms = nil - m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) ClearChannels() { + m.clearedchannels = true } -// PlatformsCleared returns if the "platforms" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] - return ok +// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared. +func (m *AuthIdentityMutation) ChannelsCleared() bool { + return m.clearedchannels } -// ResetPlatforms resets all changes to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { - m.platforms = nil - m.appendplatforms = nil - delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs. +func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) { + if m.removedchannels == nil { + m.removedchannels = make(map[int64]struct{}) + } + for i := range ids { + delete(m.channels, ids[i]) + m.removedchannels[ids[i]] = struct{}{} + } } -// SetPassthroughCode sets the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { - m.passthrough_code = &b +// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) { + for id := range m.removedchannels { + ids = append(ids, id) + } + return } -// PassthroughCode returns the value of the "passthrough_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { - v := m.passthrough_code - if v == nil { - return +// ChannelsIDs returns the "channels" edge IDs in the mutation. +func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) { + for id := range m.channels { + ids = append(ids, id) } - return *v, true + return } -// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughCode requires an ID field in the mutation") +// ResetChannels resets all changes to the "channels" edge. +func (m *AuthIdentityMutation) ResetChannels() { + m.channels = nil + m.clearedchannels = false + m.removedchannels = nil +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids. +func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) { + if m.adoption_decisions == nil { + m.adoption_decisions = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + for i := range ids { + m.adoption_decisions[ids[i]] = struct{}{} } - return oldValue.PassthroughCode, nil } -// ResetPassthroughCode resets all changes to the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { - m.passthrough_code = nil +// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) ClearAdoptionDecisions() { + m.clearedadoption_decisions = true } -// SetResponseCode sets the "response_code" field. -func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { - m.response_code = &i - m.addresponse_code = nil +// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared. +func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool { + return m.clearedadoption_decisions } -// ResponseCode returns the value of the "response_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { - v := m.response_code - if v == nil { - return +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) { + if m.removedadoption_decisions == nil { + m.removedadoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.adoption_decisions, ids[i]) + m.removedadoption_decisions[ids[i]] = struct{}{} } - return *v, true } -// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseCode requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) +// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) { + for id := range m.removedadoption_decisions { + ids = append(ids, id) } - return oldValue.ResponseCode, nil + return } -// AddResponseCode adds i to the "response_code" field. -func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { - if m.addresponse_code != nil { - *m.addresponse_code += i - } else { - m.addresponse_code = &i +// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation. +func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) { + for id := range m.adoption_decisions { + ids = append(ids, id) } + return } -// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { - v := m.addresponse_code - if v == nil { - return - } - return *v, true +// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge. +func (m *AuthIdentityMutation) ResetAdoptionDecisions() { + m.adoption_decisions = nil + m.clearedadoption_decisions = false + m.removedadoption_decisions = nil } -// ClearResponseCode clears the value of the "response_code" field. -func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { - m.response_code = nil - m.addresponse_code = nil - m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +// Where appends a list predicates to the AuthIdentityMutation builder. +func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) { + m.predicates = append(m.predicates, ps...) } -// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] - return ok +// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// ResetResponseCode resets all changes to the "response_code" field. -func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { - m.response_code = nil - m.addresponse_code = nil - delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +// Op returns the operation name. +func (m *AuthIdentityMutation) Op() Op { + return m.op } -// SetPassthroughBody sets the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { - m.passthrough_body = &b +// SetOp allows setting the mutation operation. +func (m *AuthIdentityMutation) SetOp(op Op) { + m.op = op } -// PassthroughBody returns the value of the "passthrough_body" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { - v := m.passthrough_body - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (AuthIdentity). +func (m *AuthIdentityMutation) Type() string { + return m.typ } -// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentity.FieldCreatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughBody requires an ID field in the mutation") + if m.updated_at != nil { + fields = append(fields, authidentity.FieldUpdatedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) + if m.user != nil { + fields = append(fields, authidentity.FieldUserID) } - return oldValue.PassthroughBody, nil -} - -// ResetPassthroughBody resets all changes to the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { - m.passthrough_body = nil -} - -// SetCustomMessage sets the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { - m.custom_message = &s -} - -// CustomMessage returns the value of the "custom_message" field in the mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { - v := m.custom_message - if v == nil { - return + if m.provider_type != nil { + fields = append(fields, authidentity.FieldProviderType) } - return *v, true -} - -// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") + if m.provider_key != nil { + fields = append(fields, authidentity.FieldProviderKey) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCustomMessage requires an ID field in the mutation") + if m.provider_subject != nil { + fields = append(fields, authidentity.FieldProviderSubject) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) + if m.verified_at != nil { + fields = append(fields, authidentity.FieldVerifiedAt) } - return oldValue.CustomMessage, nil -} - -// ClearCustomMessage clears the value of the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { - m.custom_message = nil - m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} + if m.issuer != nil { + fields = append(fields, authidentity.FieldIssuer) + } + if m.metadata != nil { + fields = append(fields, authidentity.FieldMetadata) + } + return fields } -// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] - return ok +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentity.FieldCreatedAt: + return m.CreatedAt() + case authidentity.FieldUpdatedAt: + return m.UpdatedAt() + case authidentity.FieldUserID: + return m.UserID() + case authidentity.FieldProviderType: + return m.ProviderType() + case authidentity.FieldProviderKey: + return m.ProviderKey() + case authidentity.FieldProviderSubject: + return m.ProviderSubject() + case authidentity.FieldVerifiedAt: + return m.VerifiedAt() + case authidentity.FieldIssuer: + return m.Issuer() + case authidentity.FieldMetadata: + return m.Metadata() + } + return nil, false } -// ResetCustomMessage resets all changes to the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { - m.custom_message = nil - delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentity.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentity.FieldUserID: + return m.OldUserID(ctx) + case authidentity.FieldProviderType: + return m.OldProviderType(ctx) + case authidentity.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentity.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case authidentity.FieldVerifiedAt: + return m.OldVerifiedAt(ctx) + case authidentity.FieldIssuer: + return m.OldIssuer(ctx) + case authidentity.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentity field %s", name) } -// SetSkipMonitoring sets the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { - m.skip_monitoring = &b +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentity.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentity.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case authidentity.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentity.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentity.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case authidentity.FieldVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVerifiedAt(v) + return nil + case authidentity.FieldIssuer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIssuer(v) + return nil + case authidentity.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) } -// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. -func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { - v := m.skip_monitoring - if v == nil { - return +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) { + switch name { } - return *v, true + return nil, false } -// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error { + switch name { } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") + return fmt.Errorf("unknown AuthIdentity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(authidentity.FieldVerifiedAt) { + fields = append(fields, authidentity.FieldVerifiedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) + if m.FieldCleared(authidentity.FieldIssuer) { + fields = append(fields, authidentity.FieldIssuer) } - return oldValue.SkipMonitoring, nil + return fields } -// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { - m.skip_monitoring = nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// SetDescription sets the "description" field. -func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { - m.description = &s +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ClearField(name string) error { + switch name { + case authidentity.FieldVerifiedAt: + m.ClearVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ClearIssuer() + return nil + } + return fmt.Errorf("unknown AuthIdentity nullable field %s", name) } -// Description returns the value of the "description" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { - v := m.description - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ResetField(name string) error { + switch name { + case authidentity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentity.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentity.FieldUserID: + m.ResetUserID() + return nil + case authidentity.FieldProviderType: + m.ResetProviderType() + return nil + case authidentity.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentity.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case authidentity.FieldVerifiedAt: + m.ResetVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ResetIssuer() + return nil + case authidentity.FieldMetadata: + m.ResetMetadata() + return nil } - return *v, true + return fmt.Errorf("unknown AuthIdentity field %s", name) } -// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.user != nil { + edges = append(edges, authidentity.EdgeUser) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + if m.channels != nil { + edges = append(edges, authidentity.EdgeChannels) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + if m.adoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) } - return oldValue.Description, nil + return edges } -// ClearDescription clears the value of the "description" field. -func (m *ErrorPassthroughRuleMutation) ClearDescription() { - m.description = nil - m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.channels)) + for id := range m.channels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.adoption_decisions)) + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil } -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] - return ok +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedchannels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.removedadoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges } -// ResetDescription resets all changes to the "description" field. -func (m *ErrorPassthroughRuleMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, errorpassthroughrule.FieldDescription) +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.removedchannels)) + for id := range m.removedchannels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.removedadoption_decisions)) + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil } -// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. -func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { - m.predicates = append(m.predicates, ps...) +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.cleareduser { + edges = append(edges, authidentity.EdgeUser) + } + if m.clearedchannels { + edges = append(edges, authidentity.EdgeChannels) + } + if m.clearedadoption_decisions { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges } -// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.ErrorPassthroughRule, len(ps)) - for i := range ps { - p[i] = ps[i] +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityMutation) EdgeCleared(name string) bool { + switch name { + case authidentity.EdgeUser: + return m.cleareduser + case authidentity.EdgeChannels: + return m.clearedchannels + case authidentity.EdgeAdoptionDecisions: + return m.clearedadoption_decisions } - m.Where(p...) + return false } -// Op returns the operation name. -func (m *ErrorPassthroughRuleMutation) Op() Op { - return m.op +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityMutation) ClearEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown AuthIdentity unique edge %s", name) } -// SetOp allows setting the mutation operation. -func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { - m.op = op +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityMutation) ResetEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ResetUser() + return nil + case authidentity.EdgeChannels: + m.ResetChannels() + return nil + case authidentity.EdgeAdoptionDecisions: + m.ResetAdoptionDecisions() + return nil + } + return fmt.Errorf("unknown AuthIdentity edge %s", name) } -// Type returns the node type of this mutation (ErrorPassthroughRule). -func (m *ErrorPassthroughRuleMutation) Type() string { - return m.typ +// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph. +type AuthIdentityChannelMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + channel *string + channel_app_id *string + channel_subject *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*AuthIdentityChannel, error) + predicates []predicate.AuthIdentityChannel } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *ErrorPassthroughRuleMutation) Fields() []string { - fields := make([]string, 0, 15) - if m.created_at != nil { - fields = append(fields, errorpassthroughrule.FieldCreatedAt) +var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil) + +// authidentitychannelOption allows management of the mutation configuration using functional options. +type authidentitychannelOption func(*AuthIdentityChannelMutation) + +// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity. +func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation { + m := &AuthIdentityChannelMutation{ + config: c, + op: op, + typ: TypeAuthIdentityChannel, + clearedFields: make(map[string]struct{}), } - if m.updated_at != nil { - fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + for _, opt := range opts { + opt(m) } - if m.name != nil { - fields = append(fields, errorpassthroughrule.FieldName) + return m +} + +// withAuthIdentityChannelID sets the ID field of the mutation. +func withAuthIdentityChannelID(id int64) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + var ( + err error + once sync.Once + value *AuthIdentityChannel + ) + m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentityChannel.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - if m.enabled != nil { - fields = append(fields, errorpassthroughrule.FieldEnabled) +} + +// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation. +func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + m.oldValue = func(context.Context) (*AuthIdentityChannel, error) { + return node, nil + } + m.id = &node.ID } - if m.priority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityChannelMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityChannelMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") } - if m.error_codes != nil { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) { + if m.id == nil { + return } - if m.keywords != nil { - fields = append(fields, errorpassthroughrule.FieldKeywords) + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } - if m.match_mode != nil { - fields = append(fields, errorpassthroughrule.FieldMatchMode) +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return } - if m.platforms != nil { - fields = append(fields, errorpassthroughrule.FieldPlatforms) + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } - if m.passthrough_code != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } - if m.response_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - if m.passthrough_body != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityChannelMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return } - if m.custom_message != nil { - fields = append(fields, errorpassthroughrule.FieldCustomMessage) + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } - if m.skip_monitoring != nil { - fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } - if m.description != nil { - fields = append(fields, errorpassthroughrule.FieldDescription) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return fields + return oldValue.UpdatedAt, nil } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { - switch name { - case errorpassthroughrule.FieldCreatedAt: - return m.CreatedAt() - case errorpassthroughrule.FieldUpdatedAt: - return m.UpdatedAt() - case errorpassthroughrule.FieldName: - return m.Name() - case errorpassthroughrule.FieldEnabled: - return m.Enabled() - case errorpassthroughrule.FieldPriority: - return m.Priority() - case errorpassthroughrule.FieldErrorCodes: - return m.ErrorCodes() - case errorpassthroughrule.FieldKeywords: - return m.Keywords() - case errorpassthroughrule.FieldMatchMode: - return m.MatchMode() - case errorpassthroughrule.FieldPlatforms: - return m.Platforms() - case errorpassthroughrule.FieldPassthroughCode: - return m.PassthroughCode() - case errorpassthroughrule.FieldResponseCode: - return m.ResponseCode() - case errorpassthroughrule.FieldPassthroughBody: - return m.PassthroughBody() - case errorpassthroughrule.FieldCustomMessage: - return m.CustomMessage() - case errorpassthroughrule.FieldSkipMonitoring: - return m.SkipMonitoring() - case errorpassthroughrule.FieldDescription: - return m.Description() +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityChannelMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return } - return nil, false + return *v, true } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case errorpassthroughrule.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case errorpassthroughrule.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case errorpassthroughrule.FieldName: - return m.OldName(ctx) - case errorpassthroughrule.FieldEnabled: - return m.OldEnabled(ctx) - case errorpassthroughrule.FieldPriority: - return m.OldPriority(ctx) - case errorpassthroughrule.FieldErrorCodes: - return m.OldErrorCodes(ctx) - case errorpassthroughrule.FieldKeywords: - return m.OldKeywords(ctx) - case errorpassthroughrule.FieldMatchMode: - return m.OldMatchMode(ctx) - case errorpassthroughrule.FieldPlatforms: - return m.OldPlatforms(ctx) - case errorpassthroughrule.FieldPassthroughCode: - return m.OldPassthroughCode(ctx) - case errorpassthroughrule.FieldResponseCode: - return m.OldResponseCode(ctx) - case errorpassthroughrule.FieldPassthroughBody: - return m.OldPassthroughBody(ctx) - case errorpassthroughrule.FieldCustomMessage: - return m.OldCustomMessage(ctx) - case errorpassthroughrule.FieldSkipMonitoring: - return m.OldSkipMonitoring(ctx) - case errorpassthroughrule.FieldDescription: - return m.OldDescription(ctx) +// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") } - return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { - switch name { - case errorpassthroughrule.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil - case errorpassthroughrule.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdatedAt(v) - return nil - case errorpassthroughrule.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case errorpassthroughrule.FieldEnabled: - v, ok := value.(bool) +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *AuthIdentityChannelMutation) ResetIdentityID() { + m.identity = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityChannelMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityChannelMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityChannelMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityChannelMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetChannel sets the "channel" field. +func (m *AuthIdentityChannelMutation) SetChannel(s string) { + m.channel = &s +} + +// Channel returns the value of the "channel" field in the mutation. +func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) { + v := m.channel + if v == nil { + return + } + return *v, true +} + +// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannel: %w", err) + } + return oldValue.Channel, nil +} + +// ResetChannel resets all changes to the "channel" field. +func (m *AuthIdentityChannelMutation) ResetChannel() { + m.channel = nil +} + +// SetChannelAppID sets the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) { + m.channel_app_id = &s +} + +// ChannelAppID returns the value of the "channel_app_id" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) { + v := m.channel_app_id + if v == nil { + return + } + return *v, true +} + +// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelAppID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err) + } + return oldValue.ChannelAppID, nil +} + +// ResetChannelAppID resets all changes to the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) ResetChannelAppID() { + m.channel_app_id = nil +} + +// SetChannelSubject sets the "channel_subject" field. +func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) { + m.channel_subject = &s +} + +// ChannelSubject returns the value of the "channel_subject" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) { + v := m.channel_subject + if v == nil { + return + } + return *v, true +} + +// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err) + } + return oldValue.ChannelSubject, nil +} + +// ResetChannelSubject resets all changes to the "channel_subject" field. +func (m *AuthIdentityChannelMutation) ResetChannelSubject() { + m.channel_subject = nil +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityChannelMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *AuthIdentityChannelMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *AuthIdentityChannelMutation) IdentityCleared() bool { + return m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *AuthIdentityChannelMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the AuthIdentityChannelMutation builder. +func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentityChannel, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityChannelMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityChannelMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentityChannel). +func (m *AuthIdentityChannelMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityChannelMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentitychannel.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentitychannel.FieldUpdatedAt) + } + if m.identity != nil { + fields = append(fields, authidentitychannel.FieldIdentityID) + } + if m.provider_type != nil { + fields = append(fields, authidentitychannel.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentitychannel.FieldProviderKey) + } + if m.channel != nil { + fields = append(fields, authidentitychannel.FieldChannel) + } + if m.channel_app_id != nil { + fields = append(fields, authidentitychannel.FieldChannelAppID) + } + if m.channel_subject != nil { + fields = append(fields, authidentitychannel.FieldChannelSubject) + } + if m.metadata != nil { + fields = append(fields, authidentitychannel.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.CreatedAt() + case authidentitychannel.FieldUpdatedAt: + return m.UpdatedAt() + case authidentitychannel.FieldIdentityID: + return m.IdentityID() + case authidentitychannel.FieldProviderType: + return m.ProviderType() + case authidentitychannel.FieldProviderKey: + return m.ProviderKey() + case authidentitychannel.FieldChannel: + return m.Channel() + case authidentitychannel.FieldChannelAppID: + return m.ChannelAppID() + case authidentitychannel.FieldChannelSubject: + return m.ChannelSubject() + case authidentitychannel.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentitychannel.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentitychannel.FieldIdentityID: + return m.OldIdentityID(ctx) + case authidentitychannel.FieldProviderType: + return m.OldProviderType(ctx) + case authidentitychannel.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentitychannel.FieldChannel: + return m.OldChannel(ctx) + case authidentitychannel.FieldChannelAppID: + return m.OldChannelAppID(ctx) + case authidentitychannel.FieldChannelSubject: + return m.OldChannelSubject(ctx) + case authidentitychannel.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentitychannel.FieldCreatedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetEnabled(v) + m.SetCreatedAt(v) return nil - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) + case authidentitychannel.FieldUpdatedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPriority(v) + m.SetUpdatedAt(v) return nil - case errorpassthroughrule.FieldErrorCodes: - v, ok := value.([]int) + case authidentitychannel.FieldIdentityID: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetErrorCodes(v) + m.SetIdentityID(v) return nil - case errorpassthroughrule.FieldKeywords: - v, ok := value.([]string) + case authidentitychannel.FieldProviderType: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetKeywords(v) + m.SetProviderType(v) return nil - case errorpassthroughrule.FieldMatchMode: + case authidentitychannel.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetMatchMode(v) + m.SetProviderKey(v) return nil - case errorpassthroughrule.FieldPlatforms: - v, ok := value.([]string) + case authidentitychannel.FieldChannel: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPlatforms(v) + m.SetChannel(v) return nil - case errorpassthroughrule.FieldPassthroughCode: - v, ok := value.(bool) + case authidentitychannel.FieldChannelAppID: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughCode(v) + m.SetChannelAppID(v) return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) + case authidentitychannel.FieldChannelSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseCode(v) + m.SetChannelSubject(v) return nil - case errorpassthroughrule.FieldPassthroughBody: - v, ok := value.(bool) + case authidentitychannel.FieldMetadata: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughBody(v) - return nil - case errorpassthroughrule.FieldCustomMessage: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCustomMessage(v) - return nil - case errorpassthroughrule.FieldSkipMonitoring: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSkipMonitoring(v) - return nil - case errorpassthroughrule.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) + m.SetMetadata(v) return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *ErrorPassthroughRuleMutation) AddedFields() []string { +func (m *AuthIdentityChannelMutation) AddedFields() []string { var fields []string - if m.addpriority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) - } - if m.addresponse_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { +func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) { switch name { - case errorpassthroughrule.FieldPriority: - return m.AddedPriority() - case errorpassthroughrule.FieldResponseCode: - return m.AddedResponseCode() } return nil, false } @@ -8028,290 +8600,205 @@ func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { +func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error { switch name { - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPriority(v) - return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddResponseCode(v) - return nil } - return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) - } - if m.FieldCleared(errorpassthroughrule.FieldKeywords) { - fields = append(fields, errorpassthroughrule.FieldKeywords) - } - if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { - fields = append(fields, errorpassthroughrule.FieldPlatforms) - } - if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { - fields = append(fields, errorpassthroughrule.FieldResponseCode) - } - if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { - fields = append(fields, errorpassthroughrule.FieldCustomMessage) - } - if m.FieldCleared(errorpassthroughrule.FieldDescription) { - fields = append(fields, errorpassthroughrule.FieldDescription) - } - return fields +func (m *AuthIdentityChannelMutation) ClearedFields() []string { + return nil } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { +func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { - switch name { - case errorpassthroughrule.FieldErrorCodes: - m.ClearErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ClearKeywords() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ClearPlatforms() - return nil - case errorpassthroughrule.FieldResponseCode: - m.ClearResponseCode() - return nil - case errorpassthroughrule.FieldCustomMessage: - m.ClearCustomMessage() - return nil - case errorpassthroughrule.FieldDescription: - m.ClearDescription() - return nil - } - return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +func (m *AuthIdentityChannelMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { +func (m *AuthIdentityChannelMutation) ResetField(name string) error { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentitychannel.FieldCreatedAt: m.ResetCreatedAt() return nil - case errorpassthroughrule.FieldUpdatedAt: + case authidentitychannel.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case errorpassthroughrule.FieldName: - m.ResetName() - return nil - case errorpassthroughrule.FieldEnabled: - m.ResetEnabled() - return nil - case errorpassthroughrule.FieldPriority: - m.ResetPriority() - return nil - case errorpassthroughrule.FieldErrorCodes: - m.ResetErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ResetKeywords() - return nil - case errorpassthroughrule.FieldMatchMode: - m.ResetMatchMode() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ResetPlatforms() + case authidentitychannel.FieldIdentityID: + m.ResetIdentityID() return nil - case errorpassthroughrule.FieldPassthroughCode: - m.ResetPassthroughCode() + case authidentitychannel.FieldProviderType: + m.ResetProviderType() return nil - case errorpassthroughrule.FieldResponseCode: - m.ResetResponseCode() + case authidentitychannel.FieldProviderKey: + m.ResetProviderKey() return nil - case errorpassthroughrule.FieldPassthroughBody: - m.ResetPassthroughBody() + case authidentitychannel.FieldChannel: + m.ResetChannel() return nil - case errorpassthroughrule.FieldCustomMessage: - m.ResetCustomMessage() + case authidentitychannel.FieldChannelAppID: + m.ResetChannelAppID() return nil - case errorpassthroughrule.FieldSkipMonitoring: - m.ResetSkipMonitoring() + case authidentitychannel.FieldChannelSubject: + m.ResetChannelSubject() return nil - case errorpassthroughrule.FieldDescription: - m.ResetDescription() + case authidentitychannel.FieldMetadata: + m.ResetMetadata() return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.identity != nil { + edges = append(edges, authidentitychannel.EdgeIdentity) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { +func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentitychannel.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { +func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedidentity { + edges = append(edges, authidentitychannel.EdgeIdentity) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { +func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool { + switch name { + case authidentitychannel.EdgeIdentity: + return m.clearedidentity + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +func (m *AuthIdentityChannelMutation) ClearEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) -} - -// GroupMutation represents an operation that mutates the Group nodes in the graph. -type GroupMutation struct { - config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - fallback_group_id_on_invalid_request *int64 - addfallback_group_id_on_invalid_request *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - mcp_xml_inject *bool - supported_model_scopes *[]string - appendsupported_model_scopes []string - sort_order *int - addsort_order *int - allow_messages_dispatch *bool - require_oauth_only *bool - require_privacy_set *bool - default_mapped_model *string - messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group -} - -var _ ent.Mutation = (*GroupMutation)(nil) - -// groupOption allows management of the mutation configuration using functional options. -type groupOption func(*GroupMutation) - -// newGroupMutation creates new mutation for the Group entity. -func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { - m := &GroupMutation{ - config: c, - op: op, - typ: TypeGroup, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) +func (m *AuthIdentityChannelMutation) ResetEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ResetIdentity() + return nil } - return m + return fmt.Errorf("unknown AuthIdentityChannel edge %s", name) } -// withGroupID sets the ID field of the mutation. -func withGroupID(id int64) groupOption { - return func(m *GroupMutation) { - var ( - err error - once sync.Once - value *Group +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + skip_monitoring *bool + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule +} + +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) + +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) + +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withErrorPassthroughRuleID sets the ID field of the mutation. +func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule ) - m.oldValue = func(ctx context.Context) (*Group, error) { + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Group.Get(ctx, id) + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) } }) return value, err @@ -8320,10 +8807,10 @@ func withGroupID(id int64) groupOption { } } -// withGroup sets the old Group of the mutation. -func withGroup(node *Group) groupOption { - return func(m *GroupMutation) { - m.oldValue = func(context.Context) (*Group, error) { +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { return node, nil } m.id = &node.ID @@ -8332,7 +8819,7 @@ func withGroup(node *Group) groupOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m GroupMutation) Client() *Client { +func (m ErrorPassthroughRuleMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -8340,7 +8827,7 @@ func (m GroupMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m GroupMutation) Tx() (*Tx, error) { +func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -8351,7 +8838,7 @@ func (m GroupMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *GroupMutation) ID() (id int64, exists bool) { +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -8362,7 +8849,7 @@ func (m *GroupMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -8371,19 +8858,19 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *GroupMutation) SetCreatedAt(t time.Time) { +func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. -func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -8391,10 +8878,10 @@ func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -8409,17 +8896,17 @@ func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err erro } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *GroupMutation) ResetCreatedAt() { +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *GroupMutation) SetUpdatedAt(t time.Time) { +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -8427,10 +8914,10 @@ func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -8445,66 +8932,17 @@ func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err erro } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *GroupMutation) ResetUpdatedAt() { +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetDeletedAt sets the "deleted_at" field. -func (m *GroupMutation) SetDeletedAt(t time.Time) { - m.deleted_at = &t -} - -// DeletedAt returns the value of the "deleted_at" field in the mutation. -func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { - v := m.deleted_at - if v == nil { - return - } - return *v, true -} - -// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDeletedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) - } - return oldValue.DeletedAt, nil -} - -// ClearDeletedAt clears the value of the "deleted_at" field. -func (m *GroupMutation) ClearDeletedAt() { - m.deleted_at = nil - m.clearedFields[group.FieldDeletedAt] = struct{}{} -} - -// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. -func (m *GroupMutation) DeletedAtCleared() bool { - _, ok := m.clearedFields[group.FieldDeletedAt] - return ok -} - -// ResetDeletedAt resets all changes to the "deleted_at" field. -func (m *GroupMutation) ResetDeletedAt() { - m.deleted_at = nil - delete(m.clearedFields, group.FieldDeletedAt) -} - // SetName sets the "name" field. -func (m *GroupMutation) SetName(s string) { +func (m *ErrorPassthroughRuleMutation) SetName(s string) { m.name = &s } // Name returns the value of the "name" field in the mutation. -func (m *GroupMutation) Name() (r string, exists bool) { +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { v := m.name if v == nil { return @@ -8512,10 +8950,10 @@ func (m *GroupMutation) Name() (r string, exists bool) { return *v, true } -// OldName returns the old "name" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldName is only allowed on UpdateOne operations") } @@ -8530,3285 +8968,3061 @@ func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { } // ResetName resets all changes to the "name" field. -func (m *GroupMutation) ResetName() { +func (m *ErrorPassthroughRuleMutation) ResetName() { m.name = nil } -// SetDescription sets the "description" field. -func (m *GroupMutation) SetDescription(s string) { - m.description = &s +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b } -// Description returns the value of the "description" field in the mutation. -func (m *GroupMutation) Description() (r string, exists bool) { - v := m.description +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled if v == nil { return } return *v, true } -// OldDescription returns the old "description" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + return v, errors.New("OldEnabled requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) } - return oldValue.Description, nil -} - -// ClearDescription clears the value of the "description" field. -func (m *GroupMutation) ClearDescription() { - m.description = nil - m.clearedFields[group.FieldDescription] = struct{}{} -} - -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *GroupMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[group.FieldDescription] - return ok + return oldValue.Enabled, nil } -// ResetDescription resets all changes to the "description" field. -func (m *GroupMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, group.FieldDescription) +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil } -// SetRateMultiplier sets the "rate_multiplier" field. -func (m *GroupMutation) SetRateMultiplier(f float64) { - m.rate_multiplier = &f - m.addrate_multiplier = nil +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil } -// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. -func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { - v := m.rate_multiplier +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority if v == nil { return } return *v, true } -// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + return v, errors.New("OldPriority is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + return v, errors.New("OldPriority requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) + return v, fmt.Errorf("querying old value for OldPriority: %w", err) } - return oldValue.RateMultiplier, nil + return oldValue.Priority, nil } -// AddRateMultiplier adds f to the "rate_multiplier" field. -func (m *GroupMutation) AddRateMultiplier(f float64) { - if m.addrate_multiplier != nil { - *m.addrate_multiplier += f +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i } else { - m.addrate_multiplier = &f + m.addpriority = &i } } -// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. -func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { - v := m.addrate_multiplier +// AddedPriority returns the value that was added to the "priority" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority if v == nil { return } return *v, true } -// ResetRateMultiplier resets all changes to the "rate_multiplier" field. -func (m *GroupMutation) ResetRateMultiplier() { - m.rate_multiplier = nil - m.addrate_multiplier = nil +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil } -// SetIsExclusive sets the "is_exclusive" field. -func (m *GroupMutation) SetIsExclusive(b bool) { - m.is_exclusive = &b +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil } -// IsExclusive returns the value of the "is_exclusive" field in the mutation. -func (m *GroupMutation) IsExclusive() (r bool, exists bool) { - v := m.is_exclusive +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes if v == nil { return } return *v, true } -// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIsExclusive requires an ID field in the mutation") + return v, errors.New("OldErrorCodes requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) } - return oldValue.IsExclusive, nil + return oldValue.ErrorCodes, nil } -// ResetIsExclusive resets all changes to the "is_exclusive" field. -func (m *GroupMutation) ResetIsExclusive() { - m.is_exclusive = nil +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) } -// SetStatus sets the "status" field. -func (m *GroupMutation) SetStatus(s string) { - m.status = &s +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true } -// Status returns the value of the "status" field in the mutation. -func (m *GroupMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return - } - return *v, true +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} } -// OldStatus returns the old "status" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) - } - return oldValue.Status, nil +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok } -// ResetStatus resets all changes to the "status" field. -func (m *GroupMutation) ResetStatus() { - m.status = nil +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) } -// SetPlatform sets the "platform" field. -func (m *GroupMutation) SetPlatform(s string) { - m.platform = &s +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil } -// Platform returns the value of the "platform" field in the mutation. -func (m *GroupMutation) Platform() (r string, exists bool) { - v := m.platform +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords if v == nil { return } return *v, true } -// OldPlatform returns the old "platform" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatform is only allowed on UpdateOne operations") + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatform requires an ID field in the mutation") + return v, errors.New("OldKeywords requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) } - return oldValue.Platform, nil + return oldValue.Keywords, nil } -// ResetPlatform resets all changes to the "platform" field. -func (m *GroupMutation) ResetPlatform() { - m.platform = nil +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) } -// SetSubscriptionType sets the "subscription_type" field. -func (m *GroupMutation) SetSubscriptionType(s string) { - m.subscription_type = &s +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true } -// SubscriptionType returns the value of the "subscription_type" field in the mutation. -func (m *GroupMutation) SubscriptionType() (r string, exists bool) { - v := m.subscription_type +// ClearKeywords clears the value of the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] + return ok +} + +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +} + +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s +} + +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode if v == nil { return } return *v, true } -// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations") + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionType requires an ID field in the mutation") + return v, errors.New("OldMatchMode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err) + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) } - return oldValue.SubscriptionType, nil + return oldValue.MatchMode, nil } -// ResetSubscriptionType resets all changes to the "subscription_type" field. -func (m *GroupMutation) ResetSubscriptionType() { - m.subscription_type = nil +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil } -// SetDailyLimitUsd sets the "daily_limit_usd" field. -func (m *GroupMutation) SetDailyLimitUsd(f float64) { - m.daily_limit_usd = &f - m.adddaily_limit_usd = nil +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil } -// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. -func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) { - v := m.daily_limit_usd +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms if v == nil { return } return *v, true } -// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldPlatforms requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) } - return oldValue.DailyLimitUsd, nil + return oldValue.Platforms, nil } -// AddDailyLimitUsd adds f to the "daily_limit_usd" field. -func (m *GroupMutation) AddDailyLimitUsd(f float64) { - if m.adddaily_limit_usd != nil { - *m.adddaily_limit_usd += f - } else { - m.adddaily_limit_usd = &f - } +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) } -// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. -func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) { - v := m.adddaily_limit_usd - if v == nil { - return +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false } - return *v, true + return m.appendplatforms, true } -// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. -func (m *GroupMutation) ClearDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - m.clearedFields[group.FieldDailyLimitUsd] = struct{}{} +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} } -// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) DailyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldDailyLimitUsd] +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] return ok } -// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. -func (m *GroupMutation) ResetDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - delete(m.clearedFields, group.FieldDailyLimitUsd) +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) } -// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. -func (m *GroupMutation) SetWeeklyLimitUsd(f float64) { - m.weekly_limit_usd = &f - m.addweekly_limit_usd = nil +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b } -// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. -func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) { - v := m.weekly_limit_usd +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code if v == nil { return } return *v, true } -// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) - } - return oldValue.WeeklyLimitUsd, nil -} - -// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. -func (m *GroupMutation) AddWeeklyLimitUsd(f float64) { - if m.addweekly_limit_usd != nil { - *m.addweekly_limit_usd += f - } else { - m.addweekly_limit_usd = &f - } -} - -// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { - v := m.addweekly_limit_usd - if v == nil { - return + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) } - return *v, true -} - -// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. -func (m *GroupMutation) ClearWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{} -} - -// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) WeeklyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldWeeklyLimitUsd] - return ok + return oldValue.PassthroughCode, nil } -// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. -func (m *GroupMutation) ResetWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - delete(m.clearedFields, group.FieldWeeklyLimitUsd) +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil } -// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. -func (m *GroupMutation) SetMonthlyLimitUsd(f float64) { - m.monthly_limit_usd = &f - m.addmonthly_limit_usd = nil +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil } -// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. -func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) { - v := m.monthly_limit_usd +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code if v == nil { return } return *v, true } -// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldResponseCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) } - return oldValue.MonthlyLimitUsd, nil + return oldValue.ResponseCode, nil } -// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. -func (m *GroupMutation) AddMonthlyLimitUsd(f float64) { - if m.addmonthly_limit_usd != nil { - *m.addmonthly_limit_usd += f +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i } else { - m.addmonthly_limit_usd = &f + m.addresponse_code = &i } } -// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { - v := m.addmonthly_limit_usd +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code if v == nil { return } return *v, true } -// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. -func (m *GroupMutation) ClearMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{} +// ClearResponseCode clears the value of the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} } -// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) MonthlyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldMonthlyLimitUsd] +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] return ok } -// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. -func (m *GroupMutation) ResetMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - delete(m.clearedFields, group.FieldMonthlyLimitUsd) +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) } -// SetDefaultValidityDays sets the "default_validity_days" field. -func (m *GroupMutation) SetDefaultValidityDays(i int) { - m.default_validity_days = &i - m.adddefault_validity_days = nil +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b } -// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. -func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { - v := m.default_validity_days +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body if v == nil { return } return *v, true } -// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) - } - return oldValue.DefaultValidityDays, nil -} - -// AddDefaultValidityDays adds i to the "default_validity_days" field. -func (m *GroupMutation) AddDefaultValidityDays(i int) { - if m.adddefault_validity_days != nil { - *m.adddefault_validity_days += i - } else { - m.adddefault_validity_days = &i - } -} - -// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. -func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { - v := m.adddefault_validity_days - if v == nil { - return + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) } - return *v, true + return oldValue.PassthroughBody, nil } -// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. -func (m *GroupMutation) ResetDefaultValidityDays() { - m.default_validity_days = nil - m.adddefault_validity_days = nil +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil } -// SetImagePrice1k sets the "image_price_1k" field. -func (m *GroupMutation) SetImagePrice1k(f float64) { - m.image_price_1k = &f - m.addimage_price_1k = nil +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s } -// ImagePrice1k returns the value of the "image_price_1k" field in the mutation. -func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) { - v := m.image_price_1k +// CustomMessage returns the value of the "custom_message" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message if v == nil { return } return *v, true } -// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations") + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice1k requires an ID field in the mutation") + return v, errors.New("OldCustomMessage requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err) + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) } - return oldValue.ImagePrice1k, nil + return oldValue.CustomMessage, nil } -// AddImagePrice1k adds f to the "image_price_1k" field. -func (m *GroupMutation) AddImagePrice1k(f float64) { - if m.addimage_price_1k != nil { - *m.addimage_price_1k += f - } else { - m.addimage_price_1k = &f - } +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} } -// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation. -func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) { - v := m.addimage_price_1k - if v == nil { - return - } - return *v, true +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] + return ok } -// ClearImagePrice1k clears the value of the "image_price_1k" field. -func (m *GroupMutation) ClearImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - m.clearedFields[group.FieldImagePrice1k] = struct{}{} +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) } -// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice1kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice1k] - return ok +// SetSkipMonitoring sets the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { + m.skip_monitoring = &b } -// ResetImagePrice1k resets all changes to the "image_price_1k" field. -func (m *GroupMutation) ResetImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - delete(m.clearedFields, group.FieldImagePrice1k) -} - -// SetImagePrice2k sets the "image_price_2k" field. -func (m *GroupMutation) SetImagePrice2k(f float64) { - m.image_price_2k = &f - m.addimage_price_2k = nil -} - -// ImagePrice2k returns the value of the "image_price_2k" field in the mutation. -func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) { - v := m.image_price_2k +// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. +func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { + v := m.skip_monitoring if v == nil { return } return *v, true } -// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations") + return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice2k requires an ID field in the mutation") + return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err) - } - return oldValue.ImagePrice2k, nil -} - -// AddImagePrice2k adds f to the "image_price_2k" field. -func (m *GroupMutation) AddImagePrice2k(f float64) { - if m.addimage_price_2k != nil { - *m.addimage_price_2k += f - } else { - m.addimage_price_2k = &f - } -} - -// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation. -func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) { - v := m.addimage_price_2k - if v == nil { - return + return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) } - return *v, true -} - -// ClearImagePrice2k clears the value of the "image_price_2k" field. -func (m *GroupMutation) ClearImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - m.clearedFields[group.FieldImagePrice2k] = struct{}{} -} - -// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice2kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice2k] - return ok + return oldValue.SkipMonitoring, nil } -// ResetImagePrice2k resets all changes to the "image_price_2k" field. -func (m *GroupMutation) ResetImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - delete(m.clearedFields, group.FieldImagePrice2k) +// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { + m.skip_monitoring = nil } -// SetImagePrice4k sets the "image_price_4k" field. -func (m *GroupMutation) SetImagePrice4k(f float64) { - m.image_price_4k = &f - m.addimage_price_4k = nil +// SetDescription sets the "description" field. +func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s } -// ImagePrice4k returns the value of the "image_price_4k" field in the mutation. -func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) { - v := m.image_price_4k +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description if v == nil { return } return *v, true } -// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations") + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice4k requires an ID field in the mutation") + return v, errors.New("OldDescription requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err) - } - return oldValue.ImagePrice4k, nil -} - -// AddImagePrice4k adds f to the "image_price_4k" field. -func (m *GroupMutation) AddImagePrice4k(f float64) { - if m.addimage_price_4k != nil { - *m.addimage_price_4k += f - } else { - m.addimage_price_4k = &f - } -} - -// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation. -func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) { - v := m.addimage_price_4k - if v == nil { - return + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return *v, true + return oldValue.Description, nil } -// ClearImagePrice4k clears the value of the "image_price_4k" field. -func (m *GroupMutation) ClearImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - m.clearedFields[group.FieldImagePrice4k] = struct{}{} +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} } -// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice4kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice4k] +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] return ok } -// ResetImagePrice4k resets all changes to the "image_price_4k" field. -func (m *GroupMutation) ResetImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - delete(m.clearedFields, group.FieldImagePrice4k) +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) } -// SetClaudeCodeOnly sets the "claude_code_only" field. -func (m *GroupMutation) SetClaudeCodeOnly(b bool) { - m.claude_code_only = &b +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) } -// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. -func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { - v := m.claude_code_only - if v == nil { - return +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] } - return *v, true + m.Where(p...) } -// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) - } - return oldValue.ClaudeCodeOnly, nil +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op } -// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. -func (m *GroupMutation) ResetClaudeCodeOnly() { - m.claude_code_only = nil +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op } -// SetFallbackGroupID sets the "fallback_group_id" field. -func (m *GroupMutation) SetFallbackGroupID(i int64) { - m.fallback_group_id = &i - m.addfallback_group_id = nil +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ } -// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. -func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { - v := m.fallback_group_id - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 15) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) } - return *v, true -} - -// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) } - return oldValue.FallbackGroupID, nil -} - -// AddFallbackGroupID adds i to the "fallback_group_id" field. -func (m *GroupMutation) AddFallbackGroupID(i int64) { - if m.addfallback_group_id != nil { - *m.addfallback_group_id += i - } else { - m.addfallback_group_id = &i + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) } -} - -// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { - v := m.addfallback_group_id - if v == nil { - return + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) } - return *v, true -} - -// ClearFallbackGroupID clears the value of the "fallback_group_id" field. -func (m *GroupMutation) ClearFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - m.clearedFields[group.FieldFallbackGroupID] = struct{}{} -} - -// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupID] - return ok -} - -// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. -func (m *GroupMutation) ResetFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - delete(m.clearedFields, group.FieldFallbackGroupID) -} - -// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { - m.fallback_group_id_on_invalid_request = &i - m.addfallback_group_id_on_invalid_request = nil -} - -// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.fallback_group_id_on_invalid_request - if v == nil { - return - } - return *v, true -} - -// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) } - return oldValue.FallbackGroupIDOnInvalidRequest, nil -} - -// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { - if m.addfallback_group_id_on_invalid_request != nil { - *m.addfallback_group_id_on_invalid_request += i - } else { - m.addfallback_group_id_on_invalid_request = &i + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) } -} - -// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.addfallback_group_id_on_invalid_request - if v == nil { - return + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) } - return *v, true -} - -// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} -} - -// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] - return ok -} - -// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) -} - -// SetModelRouting sets the "model_routing" field. -func (m *GroupMutation) SetModelRouting(value map[string][]int64) { - m.model_routing = &value -} - -// ModelRouting returns the value of the "model_routing" field in the mutation. -func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { - v := m.model_routing - if v == nil { - return + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) } - return *v, true -} - -// OldModelRouting returns the old "model_routing" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRouting requires an ID field in the mutation") + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) + if m.skip_monitoring != nil { + fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) } - return oldValue.ModelRouting, nil -} - -// ClearModelRouting clears the value of the "model_routing" field. -func (m *GroupMutation) ClearModelRouting() { - m.model_routing = nil - m.clearedFields[group.FieldModelRouting] = struct{}{} -} - -// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. -func (m *GroupMutation) ModelRoutingCleared() bool { - _, ok := m.clearedFields[group.FieldModelRouting] - return ok -} - -// ResetModelRouting resets all changes to the "model_routing" field. -func (m *GroupMutation) ResetModelRouting() { - m.model_routing = nil - delete(m.clearedFields, group.FieldModelRouting) -} - -// SetModelRoutingEnabled sets the "model_routing_enabled" field. -func (m *GroupMutation) SetModelRoutingEnabled(b bool) { - m.model_routing_enabled = &b -} - -// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. -func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { - v := m.model_routing_enabled - if v == nil { - return + if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) } - return *v, true + return fields } -// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldSkipMonitoring: + return m.SkipMonitoring() + case errorpassthroughrule.FieldDescription: + return m.Description() } - return oldValue.ModelRoutingEnabled, nil -} - -// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. -func (m *GroupMutation) ResetModelRoutingEnabled() { - m.model_routing_enabled = nil -} - -// SetMcpXMLInject sets the "mcp_xml_inject" field. -func (m *GroupMutation) SetMcpXMLInject(b bool) { - m.mcp_xml_inject = &b + return nil, false } -// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. -func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { - v := m.mcp_xml_inject - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldSkipMonitoring: + return m.OldSkipMonitoring(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) } - return *v, true + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldSkipMonitoring: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSkipMonitoring(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil } - return oldValue.McpXMLInject, nil + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. -func (m *GroupMutation) ResetMcpXMLInject() { - m.mcp_xml_inject = nil +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields } -// SetSupportedModelScopes sets the "supported_model_scopes" field. -func (m *GroupMutation) SetSupportedModelScopes(s []string) { - m.supported_model_scopes = &s - m.appendsupported_model_scopes = nil +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false } -// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. -func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { - v := m.supported_model_scopes - if v == nil { - return +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil } - return *v, true + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) } -// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) } - return oldValue.SupportedModelScopes, nil -} - -// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. -func (m *GroupMutation) AppendSupportedModelScopes(s []string) { - m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) -} - -// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. -func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { - if len(m.appendsupported_model_scopes) == 0 { - return nil, false + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) } - return m.appendsupported_model_scopes, true + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields } -// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. -func (m *GroupMutation) ResetSupportedModelScopes() { - m.supported_model_scopes = nil - m.appendsupported_model_scopes = nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// SetSortOrder sets the "sort_order" field. -func (m *GroupMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) } -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *GroupMutation) SortOrder() (r int, exists bool) { - v := m.sort_order - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return nil + case errorpassthroughrule.FieldSkipMonitoring: + m.ResetSkipMonitoring() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil } - return *v, true + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// OldSortOrder returns the old "sort_order" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) - } - return oldValue.SortOrder, nil +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// AddSortOrder adds i to the "sort_order" field. -func (m *GroupMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i - } +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil } -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. -func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order - if v == nil { - return - } - return *v, true +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *GroupMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil } -// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. -func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { - m.allow_messages_dispatch = &b +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. -func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { - v := m.allow_messages_dispatch - if v == nil { - return - } - return *v, true +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false } -// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) - } - return oldValue.AllowMessagesDispatch, nil +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) } -// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. -func (m *GroupMutation) ResetAllowMessagesDispatch() { - m.allow_messages_dispatch = nil +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) } -// SetRequireOauthOnly sets the "require_oauth_only" field. -func (m *GroupMutation) SetRequireOauthOnly(b bool) { - m.require_oauth_only = &b +// GroupMutation represents an operation that mutates the Group nodes in the graph. +type GroupMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + sort_order *int + addsort_order *int + allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool + default_mapped_model *string + messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } -// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. -func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { - v := m.require_oauth_only - if v == nil { - return - } - return *v, true -} +var _ ent.Mutation = (*GroupMutation)(nil) -// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") +// groupOption allows management of the mutation configuration using functional options. +type groupOption func(*GroupMutation) + +// newGroupMutation creates new mutation for the Group entity. +func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { + m := &GroupMutation{ + config: c, + op: op, + typ: TypeGroup, + clearedFields: make(map[string]struct{}), } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) + for _, opt := range opts { + opt(m) } - return oldValue.RequireOauthOnly, nil + return m } -// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. -func (m *GroupMutation) ResetRequireOauthOnly() { - m.require_oauth_only = nil +// withGroupID sets the ID field of the mutation. +func withGroupID(id int64) groupOption { + return func(m *GroupMutation) { + var ( + err error + once sync.Once + value *Group + ) + m.oldValue = func(ctx context.Context) (*Group, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Group.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// SetRequirePrivacySet sets the "require_privacy_set" field. -func (m *GroupMutation) SetRequirePrivacySet(b bool) { - m.require_privacy_set = &b +// withGroup sets the old Group of the mutation. +func withGroup(node *Group) groupOption { + return func(m *GroupMutation) { + m.oldValue = func(context.Context) (*Group, error) { + return node, nil + } + m.id = &node.ID + } } -// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. -func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { - v := m.require_privacy_set +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m GroupMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m GroupMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *GroupMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *GroupMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// OldCreatedAt returns the old "created_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { +func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.RequirePrivacySet, nil + return oldValue.CreatedAt, nil } -// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. -func (m *GroupMutation) ResetRequirePrivacySet() { - m.require_privacy_set = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *GroupMutation) ResetCreatedAt() { + m.created_at = nil } -// SetDefaultMappedModel sets the "default_mapped_model" field. -func (m *GroupMutation) SetDefaultMappedModel(s string) { - m.default_mapped_model = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *GroupMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. -func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { - v := m.default_mapped_model +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.DefaultMappedModel, nil + return oldValue.UpdatedAt, nil } -// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. -func (m *GroupMutation) ResetDefaultMappedModel() { - m.default_mapped_model = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *GroupMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. -func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { - m.messages_dispatch_model_config = &damdmc +// SetDeletedAt sets the "deleted_at" field. +func (m *GroupMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t } -// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. -func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { - v := m.messages_dispatch_model_config +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at if v == nil { return } return *v, true } -// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { +func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + return v, errors.New("OldDeletedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) } - return oldValue.MessagesDispatchModelConfig, nil + return oldValue.DeletedAt, nil } -// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. -func (m *GroupMutation) ResetMessagesDispatchModelConfig() { - m.messages_dispatch_model_config = nil +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *GroupMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[group.FieldDeletedAt] = struct{}{} } -// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. -func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { - if m.api_keys == nil { - m.api_keys = make(map[int64]struct{}) - } - for i := range ids { - m.api_keys[ids[i]] = struct{}{} - } +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *GroupMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[group.FieldDeletedAt] + return ok } -// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) ClearAPIKeys() { - m.clearedapi_keys = true +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *GroupMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, group.FieldDeletedAt) } -// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. -func (m *GroupMutation) APIKeysCleared() bool { - return m.clearedapi_keys +// SetName sets the "name" field. +func (m *GroupMutation) SetName(s string) { + m.name = &s } -// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. -func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { - if m.removedapi_keys == nil { - m.removedapi_keys = make(map[int64]struct{}) - } - for i := range ids { - delete(m.api_keys, ids[i]) - m.removedapi_keys[ids[i]] = struct{}{} +// Name returns the value of the "name" field in the mutation. +func (m *GroupMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return } + return *v, true } -// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { - for id := range m.removedapi_keys { - ids = append(ids, id) +// OldName returns the old "name" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") } - return -} - -// APIKeysIDs returns the "api_keys" edge IDs in the mutation. -func (m *GroupMutation) APIKeysIDs() (ids []int64) { - for id := range m.api_keys { - ids = append(ids, id) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") } - return + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil } -// ResetAPIKeys resets all changes to the "api_keys" edge. -func (m *GroupMutation) ResetAPIKeys() { - m.api_keys = nil - m.clearedapi_keys = false - m.removedapi_keys = nil +// ResetName resets all changes to the "name" field. +func (m *GroupMutation) ResetName() { + m.name = nil } -// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. -func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { - if m.redeem_codes == nil { - m.redeem_codes = make(map[int64]struct{}) - } - for i := range ids { - m.redeem_codes[ids[i]] = struct{}{} - } +// SetDescription sets the "description" field. +func (m *GroupMutation) SetDescription(s string) { + m.description = &s } -// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) ClearRedeemCodes() { - m.clearedredeem_codes = true -} - -// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. -func (m *GroupMutation) RedeemCodesCleared() bool { - return m.clearedredeem_codes +// Description returns the value of the "description" field in the mutation. +func (m *GroupMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true } -// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. -func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { - if m.removedredeem_codes == nil { - m.removedredeem_codes = make(map[int64]struct{}) +// OldDescription returns the old "description" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.redeem_codes, ids[i]) - m.removedredeem_codes[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") } -} - -// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { - for id := range m.removedredeem_codes { - ids = append(ids, id) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return + return oldValue.Description, nil } -// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. -func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { - for id := range m.redeem_codes { - ids = append(ids, id) - } - return +// ClearDescription clears the value of the "description" field. +func (m *GroupMutation) ClearDescription() { + m.description = nil + m.clearedFields[group.FieldDescription] = struct{}{} } -// ResetRedeemCodes resets all changes to the "redeem_codes" edge. -func (m *GroupMutation) ResetRedeemCodes() { - m.redeem_codes = nil - m.clearedredeem_codes = false - m.removedredeem_codes = nil +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *GroupMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[group.FieldDescription] + return ok } -// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. -func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { - if m.subscriptions == nil { - m.subscriptions = make(map[int64]struct{}) - } - for i := range ids { - m.subscriptions[ids[i]] = struct{}{} - } +// ResetDescription resets all changes to the "description" field. +func (m *GroupMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, group.FieldDescription) } -// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) ClearSubscriptions() { - m.clearedsubscriptions = true +// SetRateMultiplier sets the "rate_multiplier" field. +func (m *GroupMutation) SetRateMultiplier(f float64) { + m.rate_multiplier = &f + m.addrate_multiplier = nil } -// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. -func (m *GroupMutation) SubscriptionsCleared() bool { - return m.clearedsubscriptions +// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. +func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { + v := m.rate_multiplier + if v == nil { + return + } + return *v, true } -// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. -func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { - if m.removedsubscriptions == nil { - m.removedsubscriptions = make(map[int64]struct{}) +// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.subscriptions, ids[i]) - m.removedsubscriptions[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateMultiplier requires an ID field in the mutation") } -} - -// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { - for id := range m.removedsubscriptions { - ids = append(ids, id) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) } - return + return oldValue.RateMultiplier, nil } -// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. -func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { - for id := range m.subscriptions { - ids = append(ids, id) +// AddRateMultiplier adds f to the "rate_multiplier" field. +func (m *GroupMutation) AddRateMultiplier(f float64) { + if m.addrate_multiplier != nil { + *m.addrate_multiplier += f + } else { + m.addrate_multiplier = &f } - return } -// ResetSubscriptions resets all changes to the "subscriptions" edge. -func (m *GroupMutation) ResetSubscriptions() { - m.subscriptions = nil - m.clearedsubscriptions = false - m.removedsubscriptions = nil +// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. +func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { + v := m.addrate_multiplier + if v == nil { + return + } + return *v, true } -// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. -func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { - if m.usage_logs == nil { - m.usage_logs = make(map[int64]struct{}) - } - for i := range ids { - m.usage_logs[ids[i]] = struct{}{} - } +// ResetRateMultiplier resets all changes to the "rate_multiplier" field. +func (m *GroupMutation) ResetRateMultiplier() { + m.rate_multiplier = nil + m.addrate_multiplier = nil } -// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) ClearUsageLogs() { - m.clearedusage_logs = true +// SetIsExclusive sets the "is_exclusive" field. +func (m *GroupMutation) SetIsExclusive(b bool) { + m.is_exclusive = &b } -// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. -func (m *GroupMutation) UsageLogsCleared() bool { - return m.clearedusage_logs +// IsExclusive returns the value of the "is_exclusive" field in the mutation. +func (m *GroupMutation) IsExclusive() (r bool, exists bool) { + v := m.is_exclusive + if v == nil { + return + } + return *v, true } -// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. -func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { - if m.removedusage_logs == nil { - m.removedusage_logs = make(map[int64]struct{}) +// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.usage_logs, ids[i]) - m.removedusage_logs[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsExclusive requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) } + return oldValue.IsExclusive, nil } -// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return +// ResetIsExclusive resets all changes to the "is_exclusive" field. +func (m *GroupMutation) ResetIsExclusive() { + m.is_exclusive = nil } -// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. -func (m *GroupMutation) UsageLogsIDs() (ids []int64) { - for id := range m.usage_logs { - ids = append(ids, id) - } - return +// SetStatus sets the "status" field. +func (m *GroupMutation) SetStatus(s string) { + m.status = &s } -// ResetUsageLogs resets all changes to the "usage_logs" edge. -func (m *GroupMutation) ResetUsageLogs() { - m.usage_logs = nil - m.clearedusage_logs = false - m.removedusage_logs = nil +// Status returns the value of the "status" field in the mutation. +func (m *GroupMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true } -// AddAccountIDs adds the "accounts" edge to the Account entity by ids. -func (m *GroupMutation) AddAccountIDs(ids ...int64) { - if m.accounts == nil { - m.accounts = make(map[int64]struct{}) +// OldStatus returns the old "status" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") } - for i := range ids { - m.accounts[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) } + return oldValue.Status, nil } -// ClearAccounts clears the "accounts" edge to the Account entity. -func (m *GroupMutation) ClearAccounts() { - m.clearedaccounts = true +// ResetStatus resets all changes to the "status" field. +func (m *GroupMutation) ResetStatus() { + m.status = nil } -// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. -func (m *GroupMutation) AccountsCleared() bool { - return m.clearedaccounts +// SetPlatform sets the "platform" field. +func (m *GroupMutation) SetPlatform(s string) { + m.platform = &s } -// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. -func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { - if m.removedaccounts == nil { - m.removedaccounts = make(map[int64]struct{}) - } - for i := range ids { - delete(m.accounts, ids[i]) - m.removedaccounts[ids[i]] = struct{}{} +// Platform returns the value of the "platform" field in the mutation. +func (m *GroupMutation) Platform() (r string, exists bool) { + v := m.platform + if v == nil { + return } + return *v, true } -// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. -func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { - for id := range m.removedaccounts { - ids = append(ids, id) +// OldPlatform returns the old "platform" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatform is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatform requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + } + return oldValue.Platform, nil } -// AccountsIDs returns the "accounts" edge IDs in the mutation. -func (m *GroupMutation) AccountsIDs() (ids []int64) { - for id := range m.accounts { - ids = append(ids, id) - } - return +// ResetPlatform resets all changes to the "platform" field. +func (m *GroupMutation) ResetPlatform() { + m.platform = nil } -// ResetAccounts resets all changes to the "accounts" edge. -func (m *GroupMutation) ResetAccounts() { - m.accounts = nil - m.clearedaccounts = false - m.removedaccounts = nil +// SetSubscriptionType sets the "subscription_type" field. +func (m *GroupMutation) SetSubscriptionType(s string) { + m.subscription_type = &s } -// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. -func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { - if m.allowed_users == nil { - m.allowed_users = make(map[int64]struct{}) +// SubscriptionType returns the value of the "subscription_type" field in the mutation. +func (m *GroupMutation) SubscriptionType() (r string, exists bool) { + v := m.subscription_type + if v == nil { + return } - for i := range ids { - m.allowed_users[ids[i]] = struct{}{} + return *v, true +} + +// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err) } + return oldValue.SubscriptionType, nil } -// ClearAllowedUsers clears the "allowed_users" edge to the User entity. -func (m *GroupMutation) ClearAllowedUsers() { - m.clearedallowed_users = true +// ResetSubscriptionType resets all changes to the "subscription_type" field. +func (m *GroupMutation) ResetSubscriptionType() { + m.subscription_type = nil } -// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. -func (m *GroupMutation) AllowedUsersCleared() bool { - return m.clearedallowed_users +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (m *GroupMutation) SetDailyLimitUsd(f float64) { + m.daily_limit_usd = &f + m.adddaily_limit_usd = nil } -// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. -func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { - if m.removedallowed_users == nil { - m.removedallowed_users = make(map[int64]struct{}) - } - for i := range ids { - delete(m.allowed_users, ids[i]) - m.removedallowed_users[ids[i]] = struct{}{} +// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. +func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) { + v := m.daily_limit_usd + if v == nil { + return } + return *v, true } -// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. -func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { - for id := range m.removedallowed_users { - ids = append(ids, id) +// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + } + return oldValue.DailyLimitUsd, nil } -// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. -func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { - for id := range m.allowed_users { - ids = append(ids, id) +// AddDailyLimitUsd adds f to the "daily_limit_usd" field. +func (m *GroupMutation) AddDailyLimitUsd(f float64) { + if m.adddaily_limit_usd != nil { + *m.adddaily_limit_usd += f + } else { + m.adddaily_limit_usd = &f } - return } -// ResetAllowedUsers resets all changes to the "allowed_users" edge. -func (m *GroupMutation) ResetAllowedUsers() { - m.allowed_users = nil - m.clearedallowed_users = false - m.removedallowed_users = nil +// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. +func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) { + v := m.adddaily_limit_usd + if v == nil { + return + } + return *v, true } -// Where appends a list predicates to the GroupMutation builder. -func (m *GroupMutation) Where(ps ...predicate.Group) { - m.predicates = append(m.predicates, ps...) +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (m *GroupMutation) ClearDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + m.clearedFields[group.FieldDailyLimitUsd] = struct{}{} } -// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.Group, len(ps)) - for i := range ps { - p[i] = ps[i] - } - m.Where(p...) +// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) DailyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldDailyLimitUsd] + return ok } -// Op returns the operation name. -func (m *GroupMutation) Op() Op { - return m.op +// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. +func (m *GroupMutation) ResetDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + delete(m.clearedFields, group.FieldDailyLimitUsd) } -// SetOp allows setting the mutation operation. -func (m *GroupMutation) SetOp(op Op) { - m.op = op +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (m *GroupMutation) SetWeeklyLimitUsd(f float64) { + m.weekly_limit_usd = &f + m.addweekly_limit_usd = nil } -// Type returns the node type of this mutation (Group). -func (m *GroupMutation) Type() string { - return m.typ +// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. +func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) { + v := m.weekly_limit_usd + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) - if m.created_at != nil { - fields = append(fields, group.FieldCreatedAt) - } - if m.updated_at != nil { - fields = append(fields, group.FieldUpdatedAt) - } - if m.deleted_at != nil { - fields = append(fields, group.FieldDeletedAt) - } - if m.name != nil { - fields = append(fields, group.FieldName) - } - if m.description != nil { - fields = append(fields, group.FieldDescription) - } - if m.rate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) - } - if m.is_exclusive != nil { - fields = append(fields, group.FieldIsExclusive) - } - if m.status != nil { - fields = append(fields, group.FieldStatus) - } - if m.platform != nil { - fields = append(fields, group.FieldPlatform) - } - if m.subscription_type != nil { - fields = append(fields, group.FieldSubscriptionType) - } - if m.daily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) - } - if m.weekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) +// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") } - if m.monthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") } - if m.default_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) } - if m.image_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) + return oldValue.WeeklyLimitUsd, nil +} + +// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. +func (m *GroupMutation) AddWeeklyLimitUsd(f float64) { + if m.addweekly_limit_usd != nil { + *m.addweekly_limit_usd += f + } else { + m.addweekly_limit_usd = &f } - if m.image_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) +} + +// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. +func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { + v := m.addweekly_limit_usd + if v == nil { + return } - if m.image_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) - } - if m.claude_code_only != nil { - fields = append(fields, group.FieldClaudeCodeOnly) - } - if m.fallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) + return *v, true +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (m *GroupMutation) ClearWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{} +} + +// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) WeeklyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldWeeklyLimitUsd] + return ok +} + +// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. +func (m *GroupMutation) ResetWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + delete(m.clearedFields, group.FieldWeeklyLimitUsd) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (m *GroupMutation) SetMonthlyLimitUsd(f float64) { + m.monthly_limit_usd = &f + m.addmonthly_limit_usd = nil +} + +// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. +func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) { + v := m.monthly_limit_usd + if v == nil { + return } - if m.fallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return *v, true +} + +// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") } - if m.model_routing != nil { - fields = append(fields, group.FieldModelRouting) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") } - if m.model_routing_enabled != nil { - fields = append(fields, group.FieldModelRoutingEnabled) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) } - if m.mcp_xml_inject != nil { - fields = append(fields, group.FieldMcpXMLInject) + return oldValue.MonthlyLimitUsd, nil +} + +// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. +func (m *GroupMutation) AddMonthlyLimitUsd(f float64) { + if m.addmonthly_limit_usd != nil { + *m.addmonthly_limit_usd += f + } else { + m.addmonthly_limit_usd = &f } - if m.supported_model_scopes != nil { - fields = append(fields, group.FieldSupportedModelScopes) +} + +// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. +func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { + v := m.addmonthly_limit_usd + if v == nil { + return } - if m.sort_order != nil { - fields = append(fields, group.FieldSortOrder) + return *v, true +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (m *GroupMutation) ClearMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{} +} + +// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) MonthlyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldMonthlyLimitUsd] + return ok +} + +// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. +func (m *GroupMutation) ResetMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + delete(m.clearedFields, group.FieldMonthlyLimitUsd) +} + +// SetDefaultValidityDays sets the "default_validity_days" field. +func (m *GroupMutation) SetDefaultValidityDays(i int) { + m.default_validity_days = &i + m.adddefault_validity_days = nil +} + +// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. +func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { + v := m.default_validity_days + if v == nil { + return } - if m.allow_messages_dispatch != nil { - fields = append(fields, group.FieldAllowMessagesDispatch) + return *v, true +} + +// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") } - if m.require_oauth_only != nil { - fields = append(fields, group.FieldRequireOauthOnly) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") } - if m.require_privacy_set != nil { - fields = append(fields, group.FieldRequirePrivacySet) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) } - if m.default_mapped_model != nil { - fields = append(fields, group.FieldDefaultMappedModel) + return oldValue.DefaultValidityDays, nil +} + +// AddDefaultValidityDays adds i to the "default_validity_days" field. +func (m *GroupMutation) AddDefaultValidityDays(i int) { + if m.adddefault_validity_days != nil { + *m.adddefault_validity_days += i + } else { + m.adddefault_validity_days = &i } - if m.messages_dispatch_model_config != nil { - fields = append(fields, group.FieldMessagesDispatchModelConfig) +} + +// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. +func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { + v := m.adddefault_validity_days + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *GroupMutation) Field(name string) (ent.Value, bool) { - switch name { - case group.FieldCreatedAt: - return m.CreatedAt() - case group.FieldUpdatedAt: - return m.UpdatedAt() - case group.FieldDeletedAt: - return m.DeletedAt() - case group.FieldName: - return m.Name() - case group.FieldDescription: - return m.Description() - case group.FieldRateMultiplier: - return m.RateMultiplier() - case group.FieldIsExclusive: - return m.IsExclusive() - case group.FieldStatus: - return m.Status() - case group.FieldPlatform: - return m.Platform() - case group.FieldSubscriptionType: - return m.SubscriptionType() - case group.FieldDailyLimitUsd: - return m.DailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.WeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.MonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.DefaultValidityDays() - case group.FieldImagePrice1k: - return m.ImagePrice1k() - case group.FieldImagePrice2k: - return m.ImagePrice2k() - case group.FieldImagePrice4k: - return m.ImagePrice4k() - case group.FieldClaudeCodeOnly: - return m.ClaudeCodeOnly() - case group.FieldFallbackGroupID: - return m.FallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.FallbackGroupIDOnInvalidRequest() - case group.FieldModelRouting: - return m.ModelRouting() - case group.FieldModelRoutingEnabled: - return m.ModelRoutingEnabled() - case group.FieldMcpXMLInject: - return m.McpXMLInject() - case group.FieldSupportedModelScopes: - return m.SupportedModelScopes() - case group.FieldSortOrder: - return m.SortOrder() - case group.FieldAllowMessagesDispatch: - return m.AllowMessagesDispatch() - case group.FieldRequireOauthOnly: - return m.RequireOauthOnly() - case group.FieldRequirePrivacySet: - return m.RequirePrivacySet() - case group.FieldDefaultMappedModel: - return m.DefaultMappedModel() - case group.FieldMessagesDispatchModelConfig: - return m.MessagesDispatchModelConfig() +// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. +func (m *GroupMutation) ResetDefaultValidityDays() { + m.default_validity_days = nil + m.adddefault_validity_days = nil +} + +// SetImagePrice1k sets the "image_price_1k" field. +func (m *GroupMutation) SetImagePrice1k(f float64) { + m.image_price_1k = &f + m.addimage_price_1k = nil +} + +// ImagePrice1k returns the value of the "image_price_1k" field in the mutation. +func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) { + v := m.image_price_1k + if v == nil { + return } - return nil, false + return *v, true } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case group.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case group.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case group.FieldDeletedAt: - return m.OldDeletedAt(ctx) - case group.FieldName: - return m.OldName(ctx) - case group.FieldDescription: - return m.OldDescription(ctx) - case group.FieldRateMultiplier: - return m.OldRateMultiplier(ctx) - case group.FieldIsExclusive: - return m.OldIsExclusive(ctx) - case group.FieldStatus: - return m.OldStatus(ctx) - case group.FieldPlatform: - return m.OldPlatform(ctx) - case group.FieldSubscriptionType: - return m.OldSubscriptionType(ctx) - case group.FieldDailyLimitUsd: - return m.OldDailyLimitUsd(ctx) - case group.FieldWeeklyLimitUsd: - return m.OldWeeklyLimitUsd(ctx) - case group.FieldMonthlyLimitUsd: - return m.OldMonthlyLimitUsd(ctx) - case group.FieldDefaultValidityDays: - return m.OldDefaultValidityDays(ctx) - case group.FieldImagePrice1k: - return m.OldImagePrice1k(ctx) - case group.FieldImagePrice2k: - return m.OldImagePrice2k(ctx) - case group.FieldImagePrice4k: - return m.OldImagePrice4k(ctx) - case group.FieldClaudeCodeOnly: - return m.OldClaudeCodeOnly(ctx) - case group.FieldFallbackGroupID: - return m.OldFallbackGroupID(ctx) - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.OldFallbackGroupIDOnInvalidRequest(ctx) - case group.FieldModelRouting: - return m.OldModelRouting(ctx) - case group.FieldModelRoutingEnabled: - return m.OldModelRoutingEnabled(ctx) - case group.FieldMcpXMLInject: - return m.OldMcpXMLInject(ctx) - case group.FieldSupportedModelScopes: - return m.OldSupportedModelScopes(ctx) - case group.FieldSortOrder: - return m.OldSortOrder(ctx) - case group.FieldAllowMessagesDispatch: - return m.OldAllowMessagesDispatch(ctx) - case group.FieldRequireOauthOnly: - return m.OldRequireOauthOnly(ctx) - case group.FieldRequirePrivacySet: - return m.OldRequirePrivacySet(ctx) - case group.FieldDefaultMappedModel: - return m.OldDefaultMappedModel(ctx) - case group.FieldMessagesDispatchModelConfig: - return m.OldMessagesDispatchModelConfig(ctx) +// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations") } - return nil, fmt.Errorf("unknown Group field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice1k requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err) + } + return oldValue.ImagePrice1k, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GroupMutation) SetField(name string, value ent.Value) error { - switch name { - case group.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil - case group.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdatedAt(v) - return nil - case group.FieldDeletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDeletedAt(v) - return nil - case group.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case group.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) - return nil - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRateMultiplier(v) - return nil - case group.FieldIsExclusive: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIsExclusive(v) - return nil - case group.FieldStatus: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStatus(v) - return nil - case group.FieldPlatform: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlatform(v) - return nil - case group.FieldSubscriptionType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionType(v) - return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDailyLimitUsd(v) - return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetWeeklyLimitUsd(v) - return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMonthlyLimitUsd(v) - return nil - case group.FieldDefaultValidityDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice4k(v) - return nil - case group.FieldClaudeCodeOnly: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetClaudeCodeOnly(v) - return nil - case group.FieldFallbackGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFallbackGroupID(v) - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldModelRouting: - v, ok := value.(map[string][]int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRouting(v) - return nil - case group.FieldModelRoutingEnabled: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRoutingEnabled(v) - return nil - case group.FieldMcpXMLInject: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMcpXMLInject(v) - return nil - case group.FieldSupportedModelScopes: - v, ok := value.([]string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSupportedModelScopes(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSortOrder(v) - return nil - case group.FieldAllowMessagesDispatch: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAllowMessagesDispatch(v) - return nil - case group.FieldRequireOauthOnly: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRequireOauthOnly(v) - return nil - case group.FieldRequirePrivacySet: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRequirePrivacySet(v) - return nil - case group.FieldDefaultMappedModel: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDefaultMappedModel(v) - return nil - case group.FieldMessagesDispatchModelConfig: - v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMessagesDispatchModelConfig(v) - return nil +// AddImagePrice1k adds f to the "image_price_1k" field. +func (m *GroupMutation) AddImagePrice1k(f float64) { + if m.addimage_price_1k != nil { + *m.addimage_price_1k += f + } else { + m.addimage_price_1k = &f } - return fmt.Errorf("unknown Group field %s", name) } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *GroupMutation) AddedFields() []string { - var fields []string - if m.addrate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) - } - if m.adddaily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) +// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation. +func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) { + v := m.addimage_price_1k + if v == nil { + return } - if m.addweekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) + return *v, true +} + +// ClearImagePrice1k clears the value of the "image_price_1k" field. +func (m *GroupMutation) ClearImagePrice1k() { + m.image_price_1k = nil + m.addimage_price_1k = nil + m.clearedFields[group.FieldImagePrice1k] = struct{}{} +} + +// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice1kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice1k] + return ok +} + +// ResetImagePrice1k resets all changes to the "image_price_1k" field. +func (m *GroupMutation) ResetImagePrice1k() { + m.image_price_1k = nil + m.addimage_price_1k = nil + delete(m.clearedFields, group.FieldImagePrice1k) +} + +// SetImagePrice2k sets the "image_price_2k" field. +func (m *GroupMutation) SetImagePrice2k(f float64) { + m.image_price_2k = &f + m.addimage_price_2k = nil +} + +// ImagePrice2k returns the value of the "image_price_2k" field in the mutation. +func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) { + v := m.image_price_2k + if v == nil { + return } - if m.addmonthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) + return *v, true +} + +// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations") } - if m.adddefault_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice2k requires an ID field in the mutation") } - if m.addimage_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err) } + return oldValue.ImagePrice2k, nil +} + +// AddImagePrice2k adds f to the "image_price_2k" field. +func (m *GroupMutation) AddImagePrice2k(f float64) { if m.addimage_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) + *m.addimage_price_2k += f + } else { + m.addimage_price_2k = &f } - if m.addimage_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) +} + +// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation. +func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) { + v := m.addimage_price_2k + if v == nil { + return } - if m.addfallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) + return *v, true +} + +// ClearImagePrice2k clears the value of the "image_price_2k" field. +func (m *GroupMutation) ClearImagePrice2k() { + m.image_price_2k = nil + m.addimage_price_2k = nil + m.clearedFields[group.FieldImagePrice2k] = struct{}{} +} + +// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice2kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice2k] + return ok +} + +// ResetImagePrice2k resets all changes to the "image_price_2k" field. +func (m *GroupMutation) ResetImagePrice2k() { + m.image_price_2k = nil + m.addimage_price_2k = nil + delete(m.clearedFields, group.FieldImagePrice2k) +} + +// SetImagePrice4k sets the "image_price_4k" field. +func (m *GroupMutation) SetImagePrice4k(f float64) { + m.image_price_4k = &f + m.addimage_price_4k = nil +} + +// ImagePrice4k returns the value of the "image_price_4k" field in the mutation. +func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) { + v := m.image_price_4k + if v == nil { + return } - if m.addfallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return *v, true +} + +// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations") } - if m.addsort_order != nil { - fields = append(fields, group.FieldSortOrder) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice4k requires an ID field in the mutation") } - return fields + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err) + } + return oldValue.ImagePrice4k, nil } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { - switch name { - case group.FieldRateMultiplier: - return m.AddedRateMultiplier() - case group.FieldDailyLimitUsd: - return m.AddedDailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.AddedWeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.AddedMonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.AddedDefaultValidityDays() - case group.FieldImagePrice1k: - return m.AddedImagePrice1k() - case group.FieldImagePrice2k: - return m.AddedImagePrice2k() - case group.FieldImagePrice4k: - return m.AddedImagePrice4k() - case group.FieldFallbackGroupID: - return m.AddedFallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.AddedFallbackGroupIDOnInvalidRequest() - case group.FieldSortOrder: - return m.AddedSortOrder() +// AddImagePrice4k adds f to the "image_price_4k" field. +func (m *GroupMutation) AddImagePrice4k(f float64) { + if m.addimage_price_4k != nil { + *m.addimage_price_4k += f + } else { + m.addimage_price_4k = &f } - return nil, false } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GroupMutation) AddField(name string, value ent.Value) error { - switch name { - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRateMultiplier(v) - return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddDailyLimitUsd(v) - return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddWeeklyLimitUsd(v) - return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddMonthlyLimitUsd(v) - return nil - case group.FieldDefaultValidityDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice4k(v) - return nil - case group.FieldFallbackGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFallbackGroupID(v) - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) - return nil +// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation. +func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) { + v := m.addimage_price_4k + if v == nil { + return } - return fmt.Errorf("unknown Group numeric field %s", name) + return *v, true } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *GroupMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(group.FieldDeletedAt) { - fields = append(fields, group.FieldDeletedAt) - } - if m.FieldCleared(group.FieldDescription) { - fields = append(fields, group.FieldDescription) +// ClearImagePrice4k clears the value of the "image_price_4k" field. +func (m *GroupMutation) ClearImagePrice4k() { + m.image_price_4k = nil + m.addimage_price_4k = nil + m.clearedFields[group.FieldImagePrice4k] = struct{}{} +} + +// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice4kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice4k] + return ok +} + +// ResetImagePrice4k resets all changes to the "image_price_4k" field. +func (m *GroupMutation) ResetImagePrice4k() { + m.image_price_4k = nil + m.addimage_price_4k = nil + delete(m.clearedFields, group.FieldImagePrice4k) +} + +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (m *GroupMutation) SetClaudeCodeOnly(b bool) { + m.claude_code_only = &b +} + +// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. +func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { + v := m.claude_code_only + if v == nil { + return } - if m.FieldCleared(group.FieldDailyLimitUsd) { - fields = append(fields, group.FieldDailyLimitUsd) + return *v, true +} + +// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") } - if m.FieldCleared(group.FieldWeeklyLimitUsd) { - fields = append(fields, group.FieldWeeklyLimitUsd) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") } - if m.FieldCleared(group.FieldMonthlyLimitUsd) { - fields = append(fields, group.FieldMonthlyLimitUsd) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) } - if m.FieldCleared(group.FieldImagePrice1k) { - fields = append(fields, group.FieldImagePrice1k) + return oldValue.ClaudeCodeOnly, nil +} + +// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. +func (m *GroupMutation) ResetClaudeCodeOnly() { + m.claude_code_only = nil +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (m *GroupMutation) SetFallbackGroupID(i int64) { + m.fallback_group_id = &i + m.addfallback_group_id = nil +} + +// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. +func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { + v := m.fallback_group_id + if v == nil { + return } - if m.FieldCleared(group.FieldImagePrice2k) { - fields = append(fields, group.FieldImagePrice2k) + return *v, true +} + +// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") } - if m.FieldCleared(group.FieldImagePrice4k) { - fields = append(fields, group.FieldImagePrice4k) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") } - if m.FieldCleared(group.FieldFallbackGroupID) { - fields = append(fields, group.FieldFallbackGroupID) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) } - if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return oldValue.FallbackGroupID, nil +} + +// AddFallbackGroupID adds i to the "fallback_group_id" field. +func (m *GroupMutation) AddFallbackGroupID(i int64) { + if m.addfallback_group_id != nil { + *m.addfallback_group_id += i + } else { + m.addfallback_group_id = &i } - if m.FieldCleared(group.FieldModelRouting) { - fields = append(fields, group.FieldModelRouting) +} + +// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { + v := m.addfallback_group_id + if v == nil { + return } - return fields + return *v, true } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *GroupMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (m *GroupMutation) ClearFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + m.clearedFields[group.FieldFallbackGroupID] = struct{}{} +} + +// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupID] return ok } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *GroupMutation) ClearField(name string) error { - switch name { - case group.FieldDeletedAt: - m.ClearDeletedAt() - return nil - case group.FieldDescription: - m.ClearDescription() - return nil - case group.FieldDailyLimitUsd: - m.ClearDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ClearWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ClearMonthlyLimitUsd() - return nil - case group.FieldImagePrice1k: - m.ClearImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ClearImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ClearImagePrice4k() - return nil - case group.FieldFallbackGroupID: - m.ClearFallbackGroupID() - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ClearFallbackGroupIDOnInvalidRequest() - return nil - case group.FieldModelRouting: - m.ClearModelRouting() - return nil +// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. +func (m *GroupMutation) ResetFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + delete(m.clearedFields, group.FieldFallbackGroupID) +} + +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return } - return fmt.Errorf("unknown Group nullable field %s", name) + return *v, true } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *GroupMutation) ResetField(name string) error { - switch name { - case group.FieldCreatedAt: - m.ResetCreatedAt() - return nil - case group.FieldUpdatedAt: - m.ResetUpdatedAt() - return nil - case group.FieldDeletedAt: - m.ResetDeletedAt() - return nil - case group.FieldName: - m.ResetName() - return nil - case group.FieldDescription: - m.ResetDescription() - return nil - case group.FieldRateMultiplier: - m.ResetRateMultiplier() - return nil - case group.FieldIsExclusive: - m.ResetIsExclusive() - return nil - case group.FieldStatus: - m.ResetStatus() - return nil - case group.FieldPlatform: - m.ResetPlatform() - return nil - case group.FieldSubscriptionType: - m.ResetSubscriptionType() - return nil - case group.FieldDailyLimitUsd: - m.ResetDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ResetWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ResetMonthlyLimitUsd() - return nil - case group.FieldDefaultValidityDays: - m.ResetDefaultValidityDays() - return nil - case group.FieldImagePrice1k: - m.ResetImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ResetImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ResetImagePrice4k() - return nil - case group.FieldClaudeCodeOnly: - m.ResetClaudeCodeOnly() - return nil - case group.FieldFallbackGroupID: - m.ResetFallbackGroupID() - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ResetFallbackGroupIDOnInvalidRequest() - return nil - case group.FieldModelRouting: - m.ResetModelRouting() - return nil - case group.FieldModelRoutingEnabled: - m.ResetModelRoutingEnabled() - return nil - case group.FieldMcpXMLInject: - m.ResetMcpXMLInject() - return nil - case group.FieldSupportedModelScopes: - m.ResetSupportedModelScopes() - return nil - case group.FieldSortOrder: - m.ResetSortOrder() - return nil - case group.FieldAllowMessagesDispatch: - m.ResetAllowMessagesDispatch() - return nil - case group.FieldRequireOauthOnly: - m.ResetRequireOauthOnly() - return nil - case group.FieldRequirePrivacySet: - m.ResetRequirePrivacySet() - return nil - case group.FieldDefaultMappedModel: - m.ResetDefaultMappedModel() - return nil - case group.FieldMessagesDispatchModelConfig: - m.ResetMessagesDispatchModelConfig() - return nil +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown Group field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 6) - if m.api_keys != nil { - edges = append(edges, group.EdgeAPIKeys) +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i } - if m.redeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return } - if m.subscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + +// SetModelRouting sets the "model_routing" field. +func (m *GroupMutation) SetModelRouting(value map[string][]int64) { + m.model_routing = &value +} + +// ModelRouting returns the value of the "model_routing" field in the mutation. +func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { + v := m.model_routing + if v == nil { + return } - if m.usage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) + return *v, true +} + +// OldModelRouting returns the old "model_routing" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") } - if m.accounts != nil { - edges = append(edges, group.EdgeAccounts) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRouting requires an ID field in the mutation") } - if m.allowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) } - return edges + return oldValue.ModelRouting, nil } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *GroupMutation) AddedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.api_keys)) - for id := range m.api_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.redeem_codes)) - for id := range m.redeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.subscriptions)) - for id := range m.subscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.usage_logs)) - for id := range m.usage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.accounts)) - for id := range m.accounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.allowed_users)) - for id := range m.allowed_users { - ids = append(ids, id) - } - return ids - } - return nil +// ClearModelRouting clears the value of the "model_routing" field. +func (m *GroupMutation) ClearModelRouting() { + m.model_routing = nil + m.clearedFields[group.FieldModelRouting] = struct{}{} } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 6) - if m.removedapi_keys != nil { - edges = append(edges, group.EdgeAPIKeys) +// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. +func (m *GroupMutation) ModelRoutingCleared() bool { + _, ok := m.clearedFields[group.FieldModelRouting] + return ok +} + +// ResetModelRouting resets all changes to the "model_routing" field. +func (m *GroupMutation) ResetModelRouting() { + m.model_routing = nil + delete(m.clearedFields, group.FieldModelRouting) +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (m *GroupMutation) SetModelRoutingEnabled(b bool) { + m.model_routing_enabled = &b +} + +// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. +func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { + v := m.model_routing_enabled + if v == nil { + return } - if m.removedredeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) + return *v, true +} + +// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") } - if m.removedsubscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") } - if m.removedusage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) - } - if m.removedaccounts != nil { - edges = append(edges, group.EdgeAccounts) - } - if m.removedallowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) } - return edges + return oldValue.ModelRoutingEnabled, nil } -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *GroupMutation) RemovedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.removedapi_keys)) - for id := range m.removedapi_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.removedredeem_codes)) - for id := range m.removedredeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.removedsubscriptions)) - for id := range m.removedsubscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.removedusage_logs)) - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.removedaccounts)) - for id := range m.removedaccounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.removedallowed_users)) - for id := range m.removedallowed_users { - ids = append(ids, id) - } - return ids - } - return nil +// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. +func (m *GroupMutation) ResetModelRoutingEnabled() { + m.model_routing_enabled = nil } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 6) - if m.clearedapi_keys { - edges = append(edges, group.EdgeAPIKeys) - } - if m.clearedredeem_codes { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.clearedsubscriptions { - edges = append(edges, group.EdgeSubscriptions) +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (m *GroupMutation) SetMcpXMLInject(b bool) { + m.mcp_xml_inject = &b +} + +// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. +func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { + v := m.mcp_xml_inject + if v == nil { + return } - if m.clearedusage_logs { - edges = append(edges, group.EdgeUsageLogs) + return *v, true +} + +// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") } - if m.clearedaccounts { - edges = append(edges, group.EdgeAccounts) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") } - if m.clearedallowed_users { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) } - return edges + return oldValue.McpXMLInject, nil } -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *GroupMutation) EdgeCleared(name string) bool { - switch name { - case group.EdgeAPIKeys: - return m.clearedapi_keys - case group.EdgeRedeemCodes: - return m.clearedredeem_codes - case group.EdgeSubscriptions: - return m.clearedsubscriptions - case group.EdgeUsageLogs: - return m.clearedusage_logs - case group.EdgeAccounts: - return m.clearedaccounts - case group.EdgeAllowedUsers: - return m.clearedallowed_users - } - return false +// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. +func (m *GroupMutation) ResetMcpXMLInject() { + m.mcp_xml_inject = nil } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *GroupMutation) ClearEdge(name string) error { - switch name { +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (m *GroupMutation) SetSupportedModelScopes(s []string) { + m.supported_model_scopes = &s + m.appendsupported_model_scopes = nil +} + +// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. +func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { + v := m.supported_model_scopes + if v == nil { + return } - return fmt.Errorf("unknown Group unique edge %s", name) + return *v, true } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *GroupMutation) ResetEdge(name string) error { - switch name { - case group.EdgeAPIKeys: - m.ResetAPIKeys() - return nil - case group.EdgeRedeemCodes: - m.ResetRedeemCodes() - return nil - case group.EdgeSubscriptions: - m.ResetSubscriptions() - return nil - case group.EdgeUsageLogs: - m.ResetUsageLogs() - return nil - case group.EdgeAccounts: - m.ResetAccounts() - return nil - case group.EdgeAllowedUsers: - m.ResetAllowedUsers() - return nil +// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown Group edge %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + } + return oldValue.SupportedModelScopes, nil } -// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. -type IdempotencyRecordMutation struct { - config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - scope *string - idempotency_key_hash *string - request_fingerprint *string - status *string - response_status *int - addresponse_status *int - response_body *string - error_reason *string - locked_until *time.Time - expires_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*IdempotencyRecord, error) - predicates []predicate.IdempotencyRecord +// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. +func (m *GroupMutation) AppendSupportedModelScopes(s []string) { + m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) } -var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) +// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. +func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { + if len(m.appendsupported_model_scopes) == 0 { + return nil, false + } + return m.appendsupported_model_scopes, true +} -// idempotencyrecordOption allows management of the mutation configuration using functional options. -type idempotencyrecordOption func(*IdempotencyRecordMutation) +// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. +func (m *GroupMutation) ResetSupportedModelScopes() { + m.supported_model_scopes = nil + m.appendsupported_model_scopes = nil +} -// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. -func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { - m := &IdempotencyRecordMutation{ - config: c, - op: op, - typ: TypeIdempotencyRecord, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) - } - return m +// SetSortOrder sets the "sort_order" field. +func (m *GroupMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil } -// withIdempotencyRecordID sets the ID field of the mutation. -func withIdempotencyRecordID(id int64) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { - var ( - err error - once sync.Once - value *IdempotencyRecord - ) - m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().IdempotencyRecord.Get(ctx, id) - } - }) - return value, err - } - m.id = &id +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *GroupMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return } + return *v, true } -// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. -func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { - m.oldValue = func(context.Context) (*IdempotencyRecord, error) { - return node, nil - } - m.id = &node.ID +// OldSortOrder returns the old "sort_order" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil } -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m IdempotencyRecordMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client -} - -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m IdempotencyRecordMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") +// AddSortOrder adds i to the "sort_order" field. +func (m *GroupMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i } - tx := &Tx{config: m.config} - tx.init() - return tx, nil } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { - if m.id == nil { +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { return } - return *m.id, true + return *v, true } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *GroupMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil } -// SetCreatedAt sets the "created_at" field. -func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { + m.allow_messages_dispatch = &b } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. +func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { + v := m.allow_messages_dispatch if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.AllowMessagesDispatch, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *IdempotencyRecordMutation) ResetCreatedAt() { - m.created_at = nil +// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. +func (m *GroupMutation) ResetAllowMessagesDispatch() { + m.allow_messages_dispatch = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (m *GroupMutation) SetRequireOauthOnly(b bool) { + m.require_oauth_only = &b } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. +func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { + v := m.require_oauth_only if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.RequireOauthOnly, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *IdempotencyRecordMutation) ResetUpdatedAt() { - m.updated_at = nil +// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. +func (m *GroupMutation) ResetRequireOauthOnly() { + m.require_oauth_only = nil } -// SetScope sets the "scope" field. -func (m *IdempotencyRecordMutation) SetScope(s string) { - m.scope = &s +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (m *GroupMutation) SetRequirePrivacySet(b bool) { + m.require_privacy_set = &b } -// Scope returns the value of the "scope" field in the mutation. -func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { - v := m.scope +// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. +func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { + v := m.require_privacy_set if v == nil { return } return *v, true } -// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldScope is only allowed on UpdateOne operations") + return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldScope requires an ID field in the mutation") + return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldScope: %w", err) + return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) } - return oldValue.Scope, nil + return oldValue.RequirePrivacySet, nil } -// ResetScope resets all changes to the "scope" field. -func (m *IdempotencyRecordMutation) ResetScope() { - m.scope = nil +// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. +func (m *GroupMutation) ResetRequirePrivacySet() { + m.require_privacy_set = nil } -// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. -func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { - m.idempotency_key_hash = &s +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (m *GroupMutation) SetDefaultMappedModel(s string) { + m.default_mapped_model = &s } -// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. -func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { - v := m.idempotency_key_hash +// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. +func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { + v := m.default_mapped_model if v == nil { return } return *v, true } -// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) } - return oldValue.IdempotencyKeyHash, nil + return oldValue.DefaultMappedModel, nil } -// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. -func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { - m.idempotency_key_hash = nil +// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. +func (m *GroupMutation) ResetDefaultMappedModel() { + m.default_mapped_model = nil } -// SetRequestFingerprint sets the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { - m.request_fingerprint = &s +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { + m.messages_dispatch_model_config = &damdmc } -// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. -func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { - v := m.request_fingerprint +// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. +func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { + v := m.messages_dispatch_model_config if v == nil { return } return *v, true } -// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) } - return oldValue.RequestFingerprint, nil + return oldValue.MessagesDispatchModelConfig, nil } -// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { - m.request_fingerprint = nil +// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. +func (m *GroupMutation) ResetMessagesDispatchModelConfig() { + m.messages_dispatch_model_config = nil } -// SetStatus sets the "status" field. -func (m *IdempotencyRecordMutation) SetStatus(s string) { - m.status = &s +// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. +func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { + if m.api_keys == nil { + m.api_keys = make(map[int64]struct{}) + } + for i := range ids { + m.api_keys[ids[i]] = struct{}{} + } } -// Status returns the value of the "status" field in the mutation. -func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return - } - return *v, true +// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) ClearAPIKeys() { + m.clearedapi_keys = true } -// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") +// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. +func (m *GroupMutation) APIKeysCleared() bool { + return m.clearedapi_keys +} + +// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. +func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { + if m.removedapi_keys == nil { + m.removedapi_keys = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + for i := range ids { + delete(m.api_keys, ids[i]) + m.removedapi_keys[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) +} + +// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { + for id := range m.removedapi_keys { + ids = append(ids, id) } - return oldValue.Status, nil + return } -// ResetStatus resets all changes to the "status" field. -func (m *IdempotencyRecordMutation) ResetStatus() { - m.status = nil +// APIKeysIDs returns the "api_keys" edge IDs in the mutation. +func (m *GroupMutation) APIKeysIDs() (ids []int64) { + for id := range m.api_keys { + ids = append(ids, id) + } + return } -// SetResponseStatus sets the "response_status" field. -func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { - m.response_status = &i - m.addresponse_status = nil +// ResetAPIKeys resets all changes to the "api_keys" edge. +func (m *GroupMutation) ResetAPIKeys() { + m.api_keys = nil + m.clearedapi_keys = false + m.removedapi_keys = nil } -// ResponseStatus returns the value of the "response_status" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { - v := m.response_status - if v == nil { - return +// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. +func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { + if m.redeem_codes == nil { + m.redeem_codes = make(map[int64]struct{}) + } + for i := range ids { + m.redeem_codes[ids[i]] = struct{}{} } - return *v, true } -// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseStatus requires an ID field in the mutation") +// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) ClearRedeemCodes() { + m.clearedredeem_codes = true +} + +// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. +func (m *GroupMutation) RedeemCodesCleared() bool { + return m.clearedredeem_codes +} + +// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. +func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { + if m.removedredeem_codes == nil { + m.removedredeem_codes = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + for i := range ids { + delete(m.redeem_codes, ids[i]) + m.removedredeem_codes[ids[i]] = struct{}{} } - return oldValue.ResponseStatus, nil } -// AddResponseStatus adds i to the "response_status" field. -func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { - if m.addresponse_status != nil { - *m.addresponse_status += i - } else { - m.addresponse_status = &i +// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { + for id := range m.removedredeem_codes { + ids = append(ids, id) } + return } -// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. -func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { - v := m.addresponse_status - if v == nil { - return +// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. +func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { + for id := range m.redeem_codes { + ids = append(ids, id) } - return *v, true + return } -// ClearResponseStatus clears the value of the "response_status" field. -func (m *IdempotencyRecordMutation) ClearResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +// ResetRedeemCodes resets all changes to the "redeem_codes" edge. +func (m *GroupMutation) ResetRedeemCodes() { + m.redeem_codes = nil + m.clearedredeem_codes = false + m.removedredeem_codes = nil } -// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] - return ok +// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. +func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { + if m.subscriptions == nil { + m.subscriptions = make(map[int64]struct{}) + } + for i := range ids { + m.subscriptions[ids[i]] = struct{}{} + } } -// ResetResponseStatus resets all changes to the "response_status" field. -func (m *IdempotencyRecordMutation) ResetResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) ClearSubscriptions() { + m.clearedsubscriptions = true } -// SetResponseBody sets the "response_body" field. -func (m *IdempotencyRecordMutation) SetResponseBody(s string) { - m.response_body = &s +// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. +func (m *GroupMutation) SubscriptionsCleared() bool { + return m.clearedsubscriptions } -// ResponseBody returns the value of the "response_body" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { - v := m.response_body - if v == nil { - return +// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. +func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { + if m.removedsubscriptions == nil { + m.removedsubscriptions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.subscriptions, ids[i]) + m.removedsubscriptions[ids[i]] = struct{}{} } - return *v, true } -// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseBody requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) +// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { + for id := range m.removedsubscriptions { + ids = append(ids, id) } - return oldValue.ResponseBody, nil + return } -// ClearResponseBody clears the value of the "response_body" field. -func (m *IdempotencyRecordMutation) ClearResponseBody() { - m.response_body = nil - m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. +func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { + for id := range m.subscriptions { + ids = append(ids, id) + } + return } -// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] - return ok +// ResetSubscriptions resets all changes to the "subscriptions" edge. +func (m *GroupMutation) ResetSubscriptions() { + m.subscriptions = nil + m.clearedsubscriptions = false + m.removedsubscriptions = nil } -// ResetResponseBody resets all changes to the "response_body" field. -func (m *IdempotencyRecordMutation) ResetResponseBody() { - m.response_body = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } } -// SetErrorReason sets the "error_reason" field. -func (m *IdempotencyRecordMutation) SetErrorReason(s string) { - m.error_reason = &s +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) ClearUsageLogs() { + m.clearedusage_logs = true } -// ErrorReason returns the value of the "error_reason" field in the mutation. -func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { - v := m.error_reason - if v == nil { - return - } - return *v, true +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *GroupMutation) UsageLogsCleared() bool { + return m.clearedusage_logs } -// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorReason requires an ID field in the mutation") + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) } - return oldValue.ErrorReason, nil + return } -// ClearErrorReason clears the value of the "error_reason" field. -func (m *IdempotencyRecordMutation) ClearErrorReason() { - m.error_reason = nil - m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *GroupMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return } -// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] - return ok +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *GroupMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil } -// ResetErrorReason resets all changes to the "error_reason" field. -func (m *IdempotencyRecordMutation) ResetErrorReason() { - m.error_reason = nil - delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +// AddAccountIDs adds the "accounts" edge to the Account entity by ids. +func (m *GroupMutation) AddAccountIDs(ids ...int64) { + if m.accounts == nil { + m.accounts = make(map[int64]struct{}) + } + for i := range ids { + m.accounts[ids[i]] = struct{}{} + } } -// SetLockedUntil sets the "locked_until" field. -func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { - m.locked_until = &t +// ClearAccounts clears the "accounts" edge to the Account entity. +func (m *GroupMutation) ClearAccounts() { + m.clearedaccounts = true } -// LockedUntil returns the value of the "locked_until" field in the mutation. -func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { - v := m.locked_until - if v == nil { - return - } - return *v, true +// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. +func (m *GroupMutation) AccountsCleared() bool { + return m.clearedaccounts } -// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") +// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. +func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { + if m.removedaccounts == nil { + m.removedaccounts = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLockedUntil requires an ID field in the mutation") + for i := range ids { + delete(m.accounts, ids[i]) + m.removedaccounts[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) +} + +// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. +func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { + for id := range m.removedaccounts { + ids = append(ids, id) } - return oldValue.LockedUntil, nil + return } -// ClearLockedUntil clears the value of the "locked_until" field. -func (m *IdempotencyRecordMutation) ClearLockedUntil() { - m.locked_until = nil - m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +// AccountsIDs returns the "accounts" edge IDs in the mutation. +func (m *GroupMutation) AccountsIDs() (ids []int64) { + for id := range m.accounts { + ids = append(ids, id) + } + return } -// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] - return ok +// ResetAccounts resets all changes to the "accounts" edge. +func (m *GroupMutation) ResetAccounts() { + m.accounts = nil + m.clearedaccounts = false + m.removedaccounts = nil } -// ResetLockedUntil resets all changes to the "locked_until" field. -func (m *IdempotencyRecordMutation) ResetLockedUntil() { - m.locked_until = nil - delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. +func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { + if m.allowed_users == nil { + m.allowed_users = make(map[int64]struct{}) + } + for i := range ids { + m.allowed_users[ids[i]] = struct{}{} + } } -// SetExpiresAt sets the "expires_at" field. -func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// ClearAllowedUsers clears the "allowed_users" edge to the User entity. +func (m *GroupMutation) ClearAllowedUsers() { + m.clearedallowed_users = true } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at - if v == nil { - return - } - return *v, true +// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. +func (m *GroupMutation) AllowedUsersCleared() bool { + return m.clearedallowed_users } -// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") +// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. +func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { + if m.removedallowed_users == nil { + m.removedallowed_users = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + for i := range ids { + delete(m.allowed_users, ids[i]) + m.removedallowed_users[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) +} + +// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. +func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { + for id := range m.removedallowed_users { + ids = append(ids, id) } - return oldValue.ExpiresAt, nil + return } -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *IdempotencyRecordMutation) ResetExpiresAt() { - m.expires_at = nil +// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. +func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { + for id := range m.allowed_users { + ids = append(ids, id) + } + return } -// Where appends a list predicates to the IdempotencyRecordMutation builder. -func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { +// ResetAllowedUsers resets all changes to the "allowed_users" edge. +func (m *GroupMutation) ResetAllowedUsers() { + m.allowed_users = nil + m.clearedallowed_users = false + m.removedallowed_users = nil +} + +// Where appends a list predicates to the GroupMutation builder. +func (m *GroupMutation) Where(ps ...predicate.Group) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.IdempotencyRecord, len(ps)) +func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Group, len(ps)) for i := range ps { p[i] = ps[i] } @@ -11816,215 +12030,511 @@ func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *IdempotencyRecordMutation) Op() Op { +func (m *GroupMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *IdempotencyRecordMutation) SetOp(op Op) { +func (m *GroupMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (IdempotencyRecord). -func (m *IdempotencyRecordMutation) Type() string { +// Type returns the node type of this mutation (Group). +func (m *GroupMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *IdempotencyRecordMutation) Fields() []string { - fields := make([]string, 0, 11) +func (m *GroupMutation) Fields() []string { + fields := make([]string, 0, 30) if m.created_at != nil { - fields = append(fields, idempotencyrecord.FieldCreatedAt) + fields = append(fields, group.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, idempotencyrecord.FieldUpdatedAt) + fields = append(fields, group.FieldUpdatedAt) } - if m.scope != nil { - fields = append(fields, idempotencyrecord.FieldScope) + if m.deleted_at != nil { + fields = append(fields, group.FieldDeletedAt) } - if m.idempotency_key_hash != nil { - fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + if m.name != nil { + fields = append(fields, group.FieldName) } - if m.request_fingerprint != nil { - fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + if m.description != nil { + fields = append(fields, group.FieldDescription) } - if m.status != nil { - fields = append(fields, idempotencyrecord.FieldStatus) + if m.rate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) } - if m.response_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.is_exclusive != nil { + fields = append(fields, group.FieldIsExclusive) } - if m.response_body != nil { - fields = append(fields, idempotencyrecord.FieldResponseBody) + if m.status != nil { + fields = append(fields, group.FieldStatus) } - if m.error_reason != nil { - fields = append(fields, idempotencyrecord.FieldErrorReason) + if m.platform != nil { + fields = append(fields, group.FieldPlatform) } - if m.locked_until != nil { - fields = append(fields, idempotencyrecord.FieldLockedUntil) + if m.subscription_type != nil { + fields = append(fields, group.FieldSubscriptionType) } - if m.expires_at != nil { - fields = append(fields, idempotencyrecord.FieldExpiresAt) + if m.daily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) } - return fields -} - -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { - switch name { - case idempotencyrecord.FieldCreatedAt: - return m.CreatedAt() - case idempotencyrecord.FieldUpdatedAt: - return m.UpdatedAt() - case idempotencyrecord.FieldScope: - return m.Scope() - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.IdempotencyKeyHash() - case idempotencyrecord.FieldRequestFingerprint: - return m.RequestFingerprint() - case idempotencyrecord.FieldStatus: - return m.Status() - case idempotencyrecord.FieldResponseStatus: - return m.ResponseStatus() - case idempotencyrecord.FieldResponseBody: - return m.ResponseBody() - case idempotencyrecord.FieldErrorReason: - return m.ErrorReason() - case idempotencyrecord.FieldLockedUntil: - return m.LockedUntil() - case idempotencyrecord.FieldExpiresAt: - return m.ExpiresAt() + if m.weekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) } - return nil, false + if m.monthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.default_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.image_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.image_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.image_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.claude_code_only != nil { + fields = append(fields, group.FieldClaudeCodeOnly) + } + if m.fallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.model_routing != nil { + fields = append(fields, group.FieldModelRouting) + } + if m.model_routing_enabled != nil { + fields = append(fields, group.FieldModelRoutingEnabled) + } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } + if m.sort_order != nil { + fields = append(fields, group.FieldSortOrder) + } + if m.allow_messages_dispatch != nil { + fields = append(fields, group.FieldAllowMessagesDispatch) + } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } + if m.default_mapped_model != nil { + fields = append(fields, group.FieldDefaultMappedModel) + } + if m.messages_dispatch_model_config != nil { + fields = append(fields, group.FieldMessagesDispatchModelConfig) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *GroupMutation) Field(name string) (ent.Value, bool) { + switch name { + case group.FieldCreatedAt: + return m.CreatedAt() + case group.FieldUpdatedAt: + return m.UpdatedAt() + case group.FieldDeletedAt: + return m.DeletedAt() + case group.FieldName: + return m.Name() + case group.FieldDescription: + return m.Description() + case group.FieldRateMultiplier: + return m.RateMultiplier() + case group.FieldIsExclusive: + return m.IsExclusive() + case group.FieldStatus: + return m.Status() + case group.FieldPlatform: + return m.Platform() + case group.FieldSubscriptionType: + return m.SubscriptionType() + case group.FieldDailyLimitUsd: + return m.DailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.WeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.MonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.DefaultValidityDays() + case group.FieldImagePrice1k: + return m.ImagePrice1k() + case group.FieldImagePrice2k: + return m.ImagePrice2k() + case group.FieldImagePrice4k: + return m.ImagePrice4k() + case group.FieldClaudeCodeOnly: + return m.ClaudeCodeOnly() + case group.FieldFallbackGroupID: + return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() + case group.FieldModelRouting: + return m.ModelRouting() + case group.FieldModelRoutingEnabled: + return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() + case group.FieldSortOrder: + return m.SortOrder() + case group.FieldAllowMessagesDispatch: + return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() + case group.FieldDefaultMappedModel: + return m.DefaultMappedModel() + case group.FieldMessagesDispatchModelConfig: + return m.MessagesDispatchModelConfig() + } + return nil, false } // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: return m.OldCreatedAt(ctx) - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case idempotencyrecord.FieldScope: - return m.OldScope(ctx) - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.OldIdempotencyKeyHash(ctx) - case idempotencyrecord.FieldRequestFingerprint: - return m.OldRequestFingerprint(ctx) - case idempotencyrecord.FieldStatus: + case group.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case group.FieldName: + return m.OldName(ctx) + case group.FieldDescription: + return m.OldDescription(ctx) + case group.FieldRateMultiplier: + return m.OldRateMultiplier(ctx) + case group.FieldIsExclusive: + return m.OldIsExclusive(ctx) + case group.FieldStatus: return m.OldStatus(ctx) - case idempotencyrecord.FieldResponseStatus: - return m.OldResponseStatus(ctx) - case idempotencyrecord.FieldResponseBody: - return m.OldResponseBody(ctx) - case idempotencyrecord.FieldErrorReason: - return m.OldErrorReason(ctx) - case idempotencyrecord.FieldLockedUntil: - return m.OldLockedUntil(ctx) - case idempotencyrecord.FieldExpiresAt: - return m.OldExpiresAt(ctx) + case group.FieldPlatform: + return m.OldPlatform(ctx) + case group.FieldSubscriptionType: + return m.OldSubscriptionType(ctx) + case group.FieldDailyLimitUsd: + return m.OldDailyLimitUsd(ctx) + case group.FieldWeeklyLimitUsd: + return m.OldWeeklyLimitUsd(ctx) + case group.FieldMonthlyLimitUsd: + return m.OldMonthlyLimitUsd(ctx) + case group.FieldDefaultValidityDays: + return m.OldDefaultValidityDays(ctx) + case group.FieldImagePrice1k: + return m.OldImagePrice1k(ctx) + case group.FieldImagePrice2k: + return m.OldImagePrice2k(ctx) + case group.FieldImagePrice4k: + return m.OldImagePrice4k(ctx) + case group.FieldClaudeCodeOnly: + return m.OldClaudeCodeOnly(ctx) + case group.FieldFallbackGroupID: + return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) + case group.FieldModelRouting: + return m.OldModelRouting(ctx) + case group.FieldModelRoutingEnabled: + return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) + case group.FieldSortOrder: + return m.OldSortOrder(ctx) + case group.FieldAllowMessagesDispatch: + return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) + case group.FieldDefaultMappedModel: + return m.OldDefaultMappedModel(ctx) + case group.FieldMessagesDispatchModelConfig: + return m.OldMessagesDispatchModelConfig(ctx) } - return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) + return nil, fmt.Errorf("unknown Group field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { +func (m *GroupMutation) SetField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case idempotencyrecord.FieldScope: - v, ok := value.(string) + case group.FieldDeletedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetScope(v) + m.SetDeletedAt(v) return nil - case idempotencyrecord.FieldIdempotencyKeyHash: + case group.FieldName: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetIdempotencyKeyHash(v) + m.SetName(v) return nil - case idempotencyrecord.FieldRequestFingerprint: + case group.FieldDescription: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRequestFingerprint(v) + m.SetDescription(v) return nil - case idempotencyrecord.FieldStatus: - v, ok := value.(string) + case group.FieldRateMultiplier: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetRateMultiplier(v) return nil - case idempotencyrecord.FieldResponseStatus: - v, ok := value.(int) + case group.FieldIsExclusive: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseStatus(v) + m.SetIsExclusive(v) return nil - case idempotencyrecord.FieldResponseBody: + case group.FieldStatus: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseBody(v) + m.SetStatus(v) return nil - case idempotencyrecord.FieldErrorReason: + case group.FieldPlatform: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetErrorReason(v) + m.SetPlatform(v) return nil - case idempotencyrecord.FieldLockedUntil: - v, ok := value.(time.Time) + case group.FieldSubscriptionType: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLockedUntil(v) + m.SetSubscriptionType(v) return nil - case idempotencyrecord.FieldExpiresAt: - v, ok := value.(time.Time) + case group.FieldDailyLimitUsd: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetExpiresAt(v) + m.SetDailyLimitUsd(v) + return nil + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyLimitUsd(v) + return nil + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyLimitUsd(v) + return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice4k(v) + return nil + case group.FieldClaudeCodeOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaudeCodeOnly(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldModelRouting: + v, ok := value.(map[string][]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRouting(v) + return nil + case group.FieldModelRoutingEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRoutingEnabled(v) + return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) + return nil + case group.FieldAllowMessagesDispatch: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowMessagesDispatch(v) + return nil + case group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil + case group.FieldDefaultMappedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultMappedModel(v) + return nil + case group.FieldMessagesDispatchModelConfig: + v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMessagesDispatchModelConfig(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown Group field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *IdempotencyRecordMutation) AddedFields() []string { +func (m *GroupMutation) AddedFields() []string { var fields []string - if m.addresponse_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.addrate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) + } + if m.adddaily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.addweekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.addmonthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.adddefault_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.addimage_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.addimage_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.addimage_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.addfallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.addsort_order != nil { + fields = append(fields, group.FieldSortOrder) } return fields } @@ -12032,10 +12542,30 @@ func (m *IdempotencyRecordMutation) AddedFields() []string { // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { +func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { switch name { - case idempotencyrecord.FieldResponseStatus: - return m.AddedResponseStatus() + case group.FieldRateMultiplier: + return m.AddedRateMultiplier() + case group.FieldDailyLimitUsd: + return m.AddedDailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.AddedWeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.AddedMonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.AddedDefaultValidityDays() + case group.FieldImagePrice1k: + return m.AddedImagePrice1k() + case group.FieldImagePrice2k: + return m.AddedImagePrice2k() + case group.FieldImagePrice4k: + return m.AddedImagePrice4k() + case group.FieldFallbackGroupID: + return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() + case group.FieldSortOrder: + return m.AddedSortOrder() } return nil, false } @@ -12043,182 +12573,524 @@ func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { +func (m *GroupMutation) AddField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldResponseStatus: - v, ok := value.(int) + case group.FieldRateMultiplier: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddResponseStatus(v) + m.AddRateMultiplier(v) return nil - } - return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) -} - -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *IdempotencyRecordMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { - fields = append(fields, idempotencyrecord.FieldResponseStatus) - } - if m.FieldCleared(idempotencyrecord.FieldResponseBody) { - fields = append(fields, idempotencyrecord.FieldResponseBody) - } - if m.FieldCleared(idempotencyrecord.FieldErrorReason) { - fields = append(fields, idempotencyrecord.FieldErrorReason) - } - if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { - fields = append(fields, idempotencyrecord.FieldLockedUntil) - } - return fields -} - -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok -} - -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *IdempotencyRecordMutation) ClearField(name string) error { - switch name { - case idempotencyrecord.FieldResponseStatus: - m.ClearResponseStatus() + case group.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyLimitUsd(v) return nil - case idempotencyrecord.FieldResponseBody: - m.ClearResponseBody() + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyLimitUsd(v) return nil - case idempotencyrecord.FieldErrorReason: - m.ClearErrorReason() + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyLimitUsd(v) return nil - case idempotencyrecord.FieldLockedUntil: - m.ClearLockedUntil() + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice4k(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) + return fmt.Errorf("unknown Group numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *GroupMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(group.FieldDeletedAt) { + fields = append(fields, group.FieldDeletedAt) + } + if m.FieldCleared(group.FieldDescription) { + fields = append(fields, group.FieldDescription) + } + if m.FieldCleared(group.FieldDailyLimitUsd) { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.FieldCleared(group.FieldWeeklyLimitUsd) { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.FieldCleared(group.FieldMonthlyLimitUsd) { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.FieldCleared(group.FieldImagePrice1k) { + fields = append(fields, group.FieldImagePrice1k) + } + if m.FieldCleared(group.FieldImagePrice2k) { + fields = append(fields, group.FieldImagePrice2k) + } + if m.FieldCleared(group.FieldImagePrice4k) { + fields = append(fields, group.FieldImagePrice4k) + } + if m.FieldCleared(group.FieldFallbackGroupID) { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.FieldCleared(group.FieldModelRouting) { + fields = append(fields, group.FieldModelRouting) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *GroupMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *GroupMutation) ClearField(name string) error { + switch name { + case group.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case group.FieldDescription: + m.ClearDescription() + return nil + case group.FieldDailyLimitUsd: + m.ClearDailyLimitUsd() + return nil + case group.FieldWeeklyLimitUsd: + m.ClearWeeklyLimitUsd() + return nil + case group.FieldMonthlyLimitUsd: + m.ClearMonthlyLimitUsd() + return nil + case group.FieldImagePrice1k: + m.ClearImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ClearImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ClearImagePrice4k() + return nil + case group.FieldFallbackGroupID: + m.ClearFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ClearModelRouting() + return nil + } + return fmt.Errorf("unknown Group nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *IdempotencyRecordMutation) ResetField(name string) error { +func (m *GroupMutation) ResetField(name string) error { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: m.ResetCreatedAt() return nil - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case idempotencyrecord.FieldScope: - m.ResetScope() + case group.FieldDeletedAt: + m.ResetDeletedAt() return nil - case idempotencyrecord.FieldIdempotencyKeyHash: - m.ResetIdempotencyKeyHash() + case group.FieldName: + m.ResetName() return nil - case idempotencyrecord.FieldRequestFingerprint: - m.ResetRequestFingerprint() + case group.FieldDescription: + m.ResetDescription() return nil - case idempotencyrecord.FieldStatus: + case group.FieldRateMultiplier: + m.ResetRateMultiplier() + return nil + case group.FieldIsExclusive: + m.ResetIsExclusive() + return nil + case group.FieldStatus: m.ResetStatus() return nil - case idempotencyrecord.FieldResponseStatus: - m.ResetResponseStatus() + case group.FieldPlatform: + m.ResetPlatform() return nil - case idempotencyrecord.FieldResponseBody: - m.ResetResponseBody() + case group.FieldSubscriptionType: + m.ResetSubscriptionType() return nil - case idempotencyrecord.FieldErrorReason: - m.ResetErrorReason() + case group.FieldDailyLimitUsd: + m.ResetDailyLimitUsd() return nil - case idempotencyrecord.FieldLockedUntil: - m.ResetLockedUntil() + case group.FieldWeeklyLimitUsd: + m.ResetWeeklyLimitUsd() return nil - case idempotencyrecord.FieldExpiresAt: - m.ResetExpiresAt() + case group.FieldMonthlyLimitUsd: + m.ResetMonthlyLimitUsd() + return nil + case group.FieldDefaultValidityDays: + m.ResetDefaultValidityDays() + return nil + case group.FieldImagePrice1k: + m.ResetImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ResetImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ResetImagePrice4k() + return nil + case group.FieldClaudeCodeOnly: + m.ResetClaudeCodeOnly() + return nil + case group.FieldFallbackGroupID: + m.ResetFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ResetModelRouting() + return nil + case group.FieldModelRoutingEnabled: + m.ResetModelRoutingEnabled() + return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil + case group.FieldSortOrder: + m.ResetSortOrder() + return nil + case group.FieldAllowMessagesDispatch: + m.ResetAllowMessagesDispatch() + return nil + case group.FieldRequireOauthOnly: + m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil + case group.FieldDefaultMappedModel: + m.ResetDefaultMappedModel() + return nil + case group.FieldMessagesDispatchModelConfig: + m.ResetMessagesDispatchModelConfig() return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown Group field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *IdempotencyRecordMutation) AddedEdges() []string { - edges := make([]string, 0, 0) - return edges -} - -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { - return nil -} +func (m *GroupMutation) AddedEdges() []string { + edges := make([]string, 0, 6) + if m.api_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.redeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.subscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.usage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.accounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.allowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.api_keys)) + for id := range m.api_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.redeem_codes)) + for id := range m.redeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.subscriptions)) + for id := range m.subscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.allowed_users)) + for id := range m.allowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} // RemovedEdges returns all edge names that were removed in this mutation. -func (m *IdempotencyRecordMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *GroupMutation) RemovedEdges() []string { + edges := make([]string, 0, 6) + if m.removedapi_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.removedredeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.removedsubscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.removedaccounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.removedallowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { +func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.removedapi_keys)) + for id := range m.removedapi_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.removedredeem_codes)) + for id := range m.removedredeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.removedsubscriptions)) + for id := range m.removedsubscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.removedallowed_users)) + for id := range m.removedallowed_users { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *IdempotencyRecordMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *GroupMutation) ClearedEdges() []string { + edges := make([]string, 0, 6) + if m.clearedapi_keys { + edges = append(edges, group.EdgeAPIKeys) + } + if m.clearedredeem_codes { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.clearedsubscriptions { + edges = append(edges, group.EdgeSubscriptions) + } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } + if m.clearedaccounts { + edges = append(edges, group.EdgeAccounts) + } + if m.clearedallowed_users { + edges = append(edges, group.EdgeAllowedUsers) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { +func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeAPIKeys: + return m.clearedapi_keys + case group.EdgeRedeemCodes: + return m.clearedredeem_codes + case group.EdgeSubscriptions: + return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs + case group.EdgeAccounts: + return m.clearedaccounts + case group.EdgeAllowedUsers: + return m.clearedallowed_users + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *IdempotencyRecordMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +func (m *GroupMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Group unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *IdempotencyRecordMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeAPIKeys: + m.ResetAPIKeys() + return nil + case group.EdgeRedeemCodes: + m.ResetRedeemCodes() + return nil + case group.EdgeSubscriptions: + m.ResetSubscriptions() + return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil + case group.EdgeAccounts: + m.ResetAccounts() + return nil + case group.EdgeAllowedUsers: + m.ResetAllowedUsers() + return nil + } + return fmt.Errorf("unknown Group edge %s", name) } -// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. -type PaymentAuditLogMutation struct { +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { config - op Op - typ string - id *int64 - order_id *string - action *string - detail *string - operator *string - created_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentAuditLog, error) - predicates []predicate.PaymentAuditLog + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord } -var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) -// paymentauditlogOption allows management of the mutation configuration using functional options. -type paymentauditlogOption func(*PaymentAuditLogMutation) +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) -// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. -func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { - m := &PaymentAuditLogMutation{ +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ config: c, op: op, - typ: TypePaymentAuditLog, + typ: TypeIdempotencyRecord, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -12227,20 +13099,20 @@ func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) return m } -// withPaymentAuditLogID sets the ID field of the mutation. -func withPaymentAuditLogID(id int64) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { var ( err error once sync.Once - value *PaymentAuditLog + value *IdempotencyRecord ) - m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentAuditLog.Get(ctx, id) + value, err = m.Client().IdempotencyRecord.Get(ctx, id) } }) return value, err @@ -12249,10 +13121,10 @@ func withPaymentAuditLogID(id int64) paymentauditlogOption { } } -// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. -func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { - m.oldValue = func(context.Context) (*PaymentAuditLog, error) { +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. +func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { return node, nil } m.id = &node.ID @@ -12261,7 +13133,7 @@ func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentAuditLogMutation) Client() *Client { +func (m IdempotencyRecordMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -12269,7 +13141,7 @@ func (m PaymentAuditLogMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentAuditLogMutation) Tx() (*Tx, error) { +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -12280,7 +13152,7 @@ func (m PaymentAuditLogMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -12291,7 +13163,7 @@ func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -12300,3381 +13172,6244 @@ func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetOrderID sets the "order_id" field. -func (m *PaymentAuditLogMutation) SetOrderID(s string) { - m.order_id = &s +// SetCreatedAt sets the "created_at" field. +func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// OrderID returns the value of the "order_id" field in the mutation. -func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { - v := m.order_id +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderID is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderID requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOrderID: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.OrderID, nil + return oldValue.CreatedAt, nil } -// ResetOrderID resets all changes to the "order_id" field. -func (m *PaymentAuditLogMutation) ResetOrderID() { - m.order_id = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil } -// SetAction sets the "action" field. -func (m *PaymentAuditLogMutation) SetAction(s string) { - m.action = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// Action returns the value of the "action" field in the mutation. -func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { - v := m.action +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldAction returns the old "action" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAction is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAction requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAction: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.Action, nil + return oldValue.UpdatedAt, nil } -// ResetAction resets all changes to the "action" field. -func (m *PaymentAuditLogMutation) ResetAction() { - m.action = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetDetail sets the "detail" field. -func (m *PaymentAuditLogMutation) SetDetail(s string) { - m.detail = &s +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s } -// Detail returns the value of the "detail" field in the mutation. -func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { - v := m.detail +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope if v == nil { return } return *v, true } -// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDetail is only allowed on UpdateOne operations") + return v, errors.New("OldScope is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDetail requires an ID field in the mutation") + return v, errors.New("OldScope requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDetail: %w", err) + return v, fmt.Errorf("querying old value for OldScope: %w", err) } - return oldValue.Detail, nil + return oldValue.Scope, nil } -// ResetDetail resets all changes to the "detail" field. -func (m *PaymentAuditLogMutation) ResetDetail() { - m.detail = nil +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil } -// SetOperator sets the "operator" field. -func (m *PaymentAuditLogMutation) SetOperator(s string) { - m.operator = &s +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s } -// Operator returns the value of the "operator" field in the mutation. -func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { - v := m.operator +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash if v == nil { return } return *v, true } -// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOperator is only allowed on UpdateOne operations") + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOperator requires an ID field in the mutation") + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOperator: %w", err) + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) } - return oldValue.Operator, nil + return oldValue.IdempotencyKeyHash, nil } -// ResetOperator resets all changes to the "operator" field. -func (m *PaymentAuditLogMutation) ResetOperator() { - m.operator = nil +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.RequestFingerprint, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentAuditLogMutation) ResetCreatedAt() { - m.created_at = nil +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil } -// Where appends a list predicates to the PaymentAuditLogMutation builder. -func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { - m.predicates = append(m.predicates, ps...) +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s } -// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentAuditLog, len(ps)) - for i := range ps { - p[i] = ps[i] +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return } - m.Where(p...) + return *v, true } -// Op returns the operation name. -func (m *PaymentAuditLogMutation) Op() Op { - return m.op +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil } -// SetOp allows setting the mutation operation. -func (m *PaymentAuditLogMutation) SetOp(op Op) { - m.op = op +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil } -// Type returns the node type of this mutation (PaymentAuditLog). -func (m *PaymentAuditLogMutation) Type() string { - return m.typ +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *PaymentAuditLogMutation) Fields() []string { - fields := make([]string, 0, 5) - if m.order_id != nil { - fields = append(fields, paymentauditlog.FieldOrderID) - } - if m.action != nil { - fields = append(fields, paymentauditlog.FieldAction) - } - if m.detail != nil { - fields = append(fields, paymentauditlog.FieldDetail) - } - if m.operator != nil { - fields = append(fields, paymentauditlog.FieldOperator) - } - if m.created_at != nil { - fields = append(fields, paymentauditlog.FieldCreatedAt) +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { - switch name { - case paymentauditlog.FieldOrderID: - return m.OrderID() - case paymentauditlog.FieldAction: - return m.Action() - case paymentauditlog.FieldDetail: - return m.Detail() - case paymentauditlog.FieldOperator: - return m.Operator() - case paymentauditlog.FieldCreatedAt: - return m.CreatedAt() +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") } - return nil, false -} - -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case paymentauditlog.FieldOrderID: - return m.OldOrderID(ctx) - case paymentauditlog.FieldAction: - return m.OldAction(ctx) - case paymentauditlog.FieldDetail: - return m.OldDetail(ctx) - case paymentauditlog.FieldOperator: - return m.OldOperator(ctx) - case paymentauditlog.FieldCreatedAt: - return m.OldCreatedAt(ctx) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") } - return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { - switch name { - case paymentauditlog.FieldOrderID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOrderID(v) - return nil - case paymentauditlog.FieldAction: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAction(v) - return nil - case paymentauditlog.FieldDetail: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDetail(v) - return nil - case paymentauditlog.FieldOperator: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOperator(v) - return nil - case paymentauditlog.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil +// AddResponseStatus adds i to the "response_status" field. +func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *PaymentAuditLogMutation) AddedFields() []string { - return nil +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { - return nil, false +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { - switch name { - } - return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *PaymentAuditLogMutation) ClearedFields() []string { - return nil +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *PaymentAuditLogMutation) ResetField(name string) error { - switch name { - case paymentauditlog.FieldOrderID: - m.ResetOrderID() - return nil - case paymentauditlog.FieldAction: - m.ResetAction() - return nil - case paymentauditlog.FieldDetail: - m.ResetDetail() - return nil - case paymentauditlog.FieldOperator: - m.ResetOperator() - return nil - case paymentauditlog.FieldCreatedAt: - m.ResetCreatedAt() - return nil +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentAuditLogMutation) AddedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { - return nil +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentAuditLogMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) } -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { - return nil +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentAuditLogMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true } -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { - return false +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *PaymentAuditLogMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog edge %s", name) -} - -// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. -type PaymentOrderMutation struct { - config - op Op - typ string - id *int64 - user_email *string - user_name *string - user_notes *string - amount *float64 - addamount *float64 - pay_amount *float64 - addpay_amount *float64 - fee_rate *float64 - addfee_rate *float64 - recharge_code *string - out_trade_no *string - payment_type *string - payment_trade_no *string - pay_url *string - qr_code *string - qr_code_img *string - order_type *string - plan_id *int64 - addplan_id *int64 - subscription_group_id *int64 - addsubscription_group_id *int64 - subscription_days *int - addsubscription_days *int - provider_instance_id *string - status *string - refund_amount *float64 - addrefund_amount *float64 - refund_reason *string - refund_at *time.Time - force_refund *bool - refund_requested_at *time.Time - refund_request_reason *string - refund_requested_by *string - expires_at *time.Time - paid_at *time.Time - completed_at *time.Time - failed_at *time.Time - failed_reason *string - client_ip *string - src_host *string - src_url *string - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - user *int64 - cleareduser bool - done bool - oldValue func(context.Context) (*PaymentOrder, error) - predicates []predicate.PaymentOrder -} - -var _ ent.Mutation = (*PaymentOrderMutation)(nil) - -// paymentorderOption allows management of the mutation configuration using functional options. -type paymentorderOption func(*PaymentOrderMutation) - -// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. -func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { - m := &PaymentOrderMutation{ - config: c, - op: op, - typ: TypePaymentOrder, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) - } - return m -} - -// withPaymentOrderID sets the ID field of the mutation. -func withPaymentOrderID(id int64) paymentorderOption { - return func(m *PaymentOrderMutation) { - var ( - err error - once sync.Once - value *PaymentOrder - ) - m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().PaymentOrder.Get(ctx, id) - } - }) - return value, err - } - m.id = &id - } -} - -// withPaymentOrder sets the old PaymentOrder of the mutation. -func withPaymentOrder(node *PaymentOrder) paymentorderOption { - return func(m *PaymentOrderMutation) { - m.oldValue = func(context.Context) (*PaymentOrder, error) { - return node, nil - } - m.id = &node.ID - } -} - -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentOrderMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client -} - -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m PaymentOrderMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") - } - tx := &Tx{config: m.config} - tx.init() - return tx, nil -} - -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *PaymentOrderMutation) ID() (id int64, exists bool) { - if m.id == nil { - return - } - return *m.id, true +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) } -// SetUserID sets the "user_id" field. -func (m *PaymentOrderMutation) SetUserID(i int64) { - m.user = &i +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t } -// UserID returns the value of the "user_id" field in the mutation. -func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { - v := m.user +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until if v == nil { return } return *v, true } -// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { +func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserID is only allowed on UpdateOne operations") + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserID requires an ID field in the mutation") + return v, errors.New("OldLockedUntil requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserID: %w", err) + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) } - return oldValue.UserID, nil + return oldValue.LockedUntil, nil } -// ResetUserID resets all changes to the "user_id" field. -func (m *PaymentOrderMutation) ResetUserID() { - m.user = nil +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} } -// SetUserEmail sets the "user_email" field. -func (m *PaymentOrderMutation) SetUserEmail(s string) { - m.user_email = &s +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok } -// UserEmail returns the value of the "user_email" field in the mutation. -func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { - v := m.user_email +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at if v == nil { return } return *v, true } -// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserEmail requires an ID field in the mutation") + return v, errors.New("OldExpiresAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) } - return oldValue.UserEmail, nil + return oldValue.ExpiresAt, nil } -// ResetUserEmail resets all changes to the "user_email" field. -func (m *PaymentOrderMutation) ResetUserEmail() { - m.user_email = nil +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil } -// SetUserName sets the "user_name" field. -func (m *PaymentOrderMutation) SetUserName(s string) { - m.user_name = &s +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) } -// UserName returns the value of the "user_name" field in the mutation. -func (m *PaymentOrderMutation) UserName() (r string, exists bool) { - v := m.user_name - if v == nil { - return - } - return *v, true -} - -// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserName is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserName requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUserName: %w", err) +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] } - return oldValue.UserName, nil + m.Where(p...) } -// ResetUserName resets all changes to the "user_name" field. -func (m *PaymentOrderMutation) ResetUserName() { - m.user_name = nil +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op } -// SetUserNotes sets the "user_notes" field. -func (m *PaymentOrderMutation) SetUserNotes(s string) { - m.user_notes = &s +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op } -// UserNotes returns the value of the "user_notes" field in the mutation. -func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { - v := m.user_notes - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ } -// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserNotes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) } - return oldValue.UserNotes, nil -} - -// ClearUserNotes clears the value of the "user_notes" field. -func (m *PaymentOrderMutation) ClearUserNotes() { - m.user_notes = nil - m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} -} - -// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. -func (m *PaymentOrderMutation) UserNotesCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldUserNotes] - return ok -} - -// ResetUserNotes resets all changes to the "user_notes" field. -func (m *PaymentOrderMutation) ResetUserNotes() { - m.user_notes = nil - delete(m.clearedFields, paymentorder.FieldUserNotes) -} - -// SetAmount sets the "amount" field. -func (m *PaymentOrderMutation) SetAmount(f float64) { - m.amount = &f - m.addamount = nil -} - -// Amount returns the value of the "amount" field in the mutation. -func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { - v := m.amount - if v == nil { - return + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) } - return *v, true -} - -// OldAmount returns the old "amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAmount is only allowed on UpdateOne operations") + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAmount requires an ID field in the mutation") + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldAmount: %w", err) + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) } - return oldValue.Amount, nil -} - -// AddAmount adds f to the "amount" field. -func (m *PaymentOrderMutation) AddAmount(f float64) { - if m.addamount != nil { - *m.addamount += f - } else { - m.addamount = &f + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) } -} - -// AddedAmount returns the value that was added to the "amount" field in this mutation. -func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { - v := m.addamount - if v == nil { - return + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return *v, true -} - -// ResetAmount resets all changes to the "amount" field. -func (m *PaymentOrderMutation) ResetAmount() { - m.amount = nil - m.addamount = nil -} - -// SetPayAmount sets the "pay_amount" field. -func (m *PaymentOrderMutation) SetPayAmount(f float64) { - m.pay_amount = &f - m.addpay_amount = nil -} - -// PayAmount returns the value of the "pay_amount" field in the mutation. -func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { - v := m.pay_amount - if v == nil { - return + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) } - return *v, true -} - -// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayAmount requires an ID field in the mutation") + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) } - return oldValue.PayAmount, nil + return fields } -// AddPayAmount adds f to the "pay_amount" field. -func (m *PaymentOrderMutation) AddPayAmount(f float64) { - if m.addpay_amount != nil { - *m.addpay_amount += f - } else { - m.addpay_amount = &f +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() } + return nil, false } -// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { - v := m.addpay_amount - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) } - return *v, true -} - -// ResetPayAmount resets all changes to the "pay_amount" field. -func (m *PaymentOrderMutation) ResetPayAmount() { - m.pay_amount = nil - m.addpay_amount = nil + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) } -// SetFeeRate sets the "fee_rate" field. -func (m *PaymentOrderMutation) SetFeeRate(f float64) { - m.fee_rate = &f - m.addfee_rate = nil -} - -// FeeRate returns the value of the "fee_rate" field in the mutation. -func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { - v := m.fee_rate - if v == nil { - return +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil } - return *v, true + return fmt.Errorf("unknown IdempotencyRecord field %s", name) } -// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFeeRate requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return oldValue.FeeRate, nil + return fields } -// AddFeeRate adds f to the "fee_rate" field. -func (m *PaymentOrderMutation) AddFeeRate(f float64) { - if m.addfee_rate != nil { - *m.addfee_rate += f - } else { - m.addfee_rate = &f +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() } + return nil, false } -// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. -func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { - v := m.addfee_rate - if v == nil { - return +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil } - return *v, true -} - -// ResetFeeRate resets all changes to the "fee_rate" field. -func (m *PaymentOrderMutation) ResetFeeRate() { - m.fee_rate = nil - m.addfee_rate = nil -} - -// SetRechargeCode sets the "recharge_code" field. -func (m *PaymentOrderMutation) SetRechargeCode(s string) { - m.recharge_code = &s + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) } -// RechargeCode returns the value of the "recharge_code" field in the mutation. -func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { - v := m.recharge_code - if v == nil { - return +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return *v, true -} - -// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRechargeCode requires an ID field in the mutation") + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) } - return oldValue.RechargeCode, nil -} - -// ResetRechargeCode resets all changes to the "recharge_code" field. -func (m *PaymentOrderMutation) ResetRechargeCode() { - m.recharge_code = nil -} - -// SetOutTradeNo sets the "out_trade_no" field. -func (m *PaymentOrderMutation) SetOutTradeNo(s string) { - m.out_trade_no = &s + return fields } -// OutTradeNo returns the value of the "out_trade_no" field in the mutation. -func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { - v := m.out_trade_no - if v == nil { - return - } - return *v, true +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOutTradeNo requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil } - return oldValue.OutTradeNo, nil -} - -// ResetOutTradeNo resets all changes to the "out_trade_no" field. -func (m *PaymentOrderMutation) ResetOutTradeNo() { - m.out_trade_no = nil -} - -// SetPaymentType sets the "payment_type" field. -func (m *PaymentOrderMutation) SetPaymentType(s string) { - m.payment_type = &s + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) } -// PaymentType returns the value of the "payment_type" field in the mutation. -func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { - v := m.payment_type +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + +// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph. +type IdentityAdoptionDecisionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + adopt_display_name *bool + adopt_avatar *bool + decided_at *time.Time + clearedFields map[string]struct{} + pending_auth_session *int64 + clearedpending_auth_session bool + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*IdentityAdoptionDecision, error) + predicates []predicate.IdentityAdoptionDecision +} + +var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil) + +// identityadoptiondecisionOption allows management of the mutation configuration using functional options. +type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation) + +// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity. +func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation { + m := &IdentityAdoptionDecisionMutation{ + config: c, + op: op, + typ: TypeIdentityAdoptionDecision, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdentityAdoptionDecisionID sets the ID field of the mutation. +func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + var ( + err error + once sync.Once + value *IdentityAdoptionDecision + ) + m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation. +func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdentityAdoptionDecisionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentType requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.PaymentType, nil + return oldValue.CreatedAt, nil } -// ResetPaymentType resets all changes to the "payment_type" field. -func (m *PaymentOrderMutation) ResetPaymentType() { - m.payment_type = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() { + m.created_at = nil } -// SetPaymentTradeNo sets the "payment_trade_no" field. -func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { - m.payment_trade_no = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. -func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { - v := m.payment_trade_no +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.PaymentTradeNo, nil + return oldValue.UpdatedAt, nil } -// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. -func (m *PaymentOrderMutation) ResetPaymentTradeNo() { - m.payment_trade_no = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetPayURL sets the "pay_url" field. -func (m *PaymentOrderMutation) SetPayURL(s string) { - m.pay_url = &s +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) { + m.pending_auth_session = &i } -// PayURL returns the value of the "pay_url" field in the mutation. -func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { - v := m.pay_url +// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) { + v := m.pending_auth_session if v == nil { return } return *v, true } -// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayURL is only allowed on UpdateOne operations") + return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayURL requires an ID field in the mutation") + return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPayURL: %w", err) + return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err) } - return oldValue.PayURL, nil + return oldValue.PendingAuthSessionID, nil } -// ClearPayURL clears the value of the "pay_url" field. -func (m *PaymentOrderMutation) ClearPayURL() { - m.pay_url = nil - m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() { + m.pending_auth_session = nil } -// PayURLCleared returns if the "pay_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) PayURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPayURL] - return ok +// SetIdentityID sets the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) { + m.identity = &i } -// ResetPayURL resets all changes to the "pay_url" field. -func (m *PaymentOrderMutation) ResetPayURL() { - m.pay_url = nil - delete(m.clearedFields, paymentorder.FieldPayURL) +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true } -// SetQrCode sets the "qr_code" field. -func (m *PaymentOrderMutation) SetQrCode(s string) { - m.qr_code = &s -} - -// QrCode returns the value of the "qr_code" field in the mutation. -func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { - v := m.qr_code - if v == nil { - return - } - return *v, true -} - -// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCode is only allowed on UpdateOne operations") + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCode requires an ID field in the mutation") + return v, errors.New("OldIdentityID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldQrCode: %w", err) + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) } - return oldValue.QrCode, nil + return oldValue.IdentityID, nil } -// ClearQrCode clears the value of the "qr_code" field. -func (m *PaymentOrderMutation) ClearQrCode() { - m.qr_code = nil - m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +// ClearIdentityID clears the value of the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() { + m.identity = nil + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} } -// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. -func (m *PaymentOrderMutation) QrCodeCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCode] +// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool { + _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID] return ok } -// ResetQrCode resets all changes to the "qr_code" field. -func (m *PaymentOrderMutation) ResetQrCode() { - m.qr_code = nil - delete(m.clearedFields, paymentorder.FieldQrCode) +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() { + m.identity = nil + delete(m.clearedFields, identityadoptiondecision.FieldIdentityID) } -// SetQrCodeImg sets the "qr_code_img" field. -func (m *PaymentOrderMutation) SetQrCodeImg(s string) { - m.qr_code_img = &s +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) { + m.adopt_display_name = &b } -// QrCodeImg returns the value of the "qr_code_img" field in the mutation. -func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { - v := m.qr_code_img +// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) { + v := m.adopt_display_name if v == nil { return } return *v, true } -// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") + return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCodeImg requires an ID field in the mutation") + return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err) } - return oldValue.QrCodeImg, nil -} - -// ClearQrCodeImg clears the value of the "qr_code_img" field. -func (m *PaymentOrderMutation) ClearQrCodeImg() { - m.qr_code_img = nil - m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} -} - -// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. -func (m *PaymentOrderMutation) QrCodeImgCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] - return ok + return oldValue.AdoptDisplayName, nil } -// ResetQrCodeImg resets all changes to the "qr_code_img" field. -func (m *PaymentOrderMutation) ResetQrCodeImg() { - m.qr_code_img = nil - delete(m.clearedFields, paymentorder.FieldQrCodeImg) +// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() { + m.adopt_display_name = nil } -// SetOrderType sets the "order_type" field. -func (m *PaymentOrderMutation) SetOrderType(s string) { - m.order_type = &s +// SetAdoptAvatar sets the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) { + m.adopt_avatar = &b } -// OrderType returns the value of the "order_type" field in the mutation. -func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { - v := m.order_type +// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) { + v := m.adopt_avatar if v == nil { return } return *v, true } -// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderType is only allowed on UpdateOne operations") + return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderType requires an ID field in the mutation") + return v, errors.New("OldAdoptAvatar requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOrderType: %w", err) + return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err) } - return oldValue.OrderType, nil + return oldValue.AdoptAvatar, nil } -// ResetOrderType resets all changes to the "order_type" field. -func (m *PaymentOrderMutation) ResetOrderType() { - m.order_type = nil +// ResetAdoptAvatar resets all changes to the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() { + m.adopt_avatar = nil } -// SetPlanID sets the "plan_id" field. -func (m *PaymentOrderMutation) SetPlanID(i int64) { - m.plan_id = &i - m.addplan_id = nil +// SetDecidedAt sets the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) { + m.decided_at = &t } -// PlanID returns the value of the "plan_id" field in the mutation. -func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { - v := m.plan_id +// DecidedAt returns the value of the "decided_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) { + v := m.decided_at if v == nil { return } return *v, true } -// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { +func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlanID is only allowed on UpdateOne operations") + return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlanID requires an ID field in the mutation") + return v, errors.New("OldDecidedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlanID: %w", err) + return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err) } - return oldValue.PlanID, nil + return oldValue.DecidedAt, nil } -// AddPlanID adds i to the "plan_id" field. -func (m *PaymentOrderMutation) AddPlanID(i int64) { - if m.addplan_id != nil { - *m.addplan_id += i - } else { - m.addplan_id = &i - } +// ResetDecidedAt resets all changes to the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() { + m.decided_at = nil } -// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. -func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { - v := m.addplan_id - if v == nil { - return - } - return *v, true +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() { + m.clearedpending_auth_session = true + m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{} } -// ClearPlanID clears the value of the "plan_id" field. -func (m *PaymentOrderMutation) ClearPlanID() { - m.plan_id = nil - m.addplan_id = nil - m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool { + return m.clearedpending_auth_session } -// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) PlanIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPlanID] - return ok +// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PendingAuthSessionID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) { + if id := m.pending_auth_session; id != nil { + ids = append(ids, *id) + } + return } -// ResetPlanID resets all changes to the "plan_id" field. -func (m *PaymentOrderMutation) ResetPlanID() { - m.plan_id = nil - m.addplan_id = nil - delete(m.clearedFields, paymentorder.FieldPlanID) +// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() { + m.pending_auth_session = nil + m.clearedpending_auth_session = false } -// SetSubscriptionGroupID sets the "subscription_group_id" field. -func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { - m.subscription_group_id = &i - m.addsubscription_group_id = nil +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *IdentityAdoptionDecisionMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} } -// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. -func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { - v := m.subscription_group_id - if v == nil { - return - } - return *v, true +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool { + return m.IdentityIDCleared() || m.clearedidentity } -// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) } - return oldValue.SubscriptionGroupID, nil + return } -// AddSubscriptionGroupID adds i to the "subscription_group_id" field. -func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { - if m.addsubscription_group_id != nil { - *m.addsubscription_group_id += i - } else { - m.addsubscription_group_id = &i - } +// ResetIdentity resets all changes to the "identity" edge. +func (m *IdentityAdoptionDecisionMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false } -// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. -func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { - v := m.addsubscription_group_id - if v == nil { - return - } - return *v, true +// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder. +func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) { + m.predicates = append(m.predicates, ps...) } -// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. -func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} +// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdentityAdoptionDecision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] - return ok +// Op returns the operation name. +func (m *IdentityAdoptionDecisionMutation) Op() Op { + return m.op } -// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. -func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +// SetOp allows setting the mutation operation. +func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) { + m.op = op } -// SetSubscriptionDays sets the "subscription_days" field. -func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { - m.subscription_days = &i - m.addsubscription_days = nil +// Type returns the node type of this mutation (IdentityAdoptionDecision). +func (m *IdentityAdoptionDecisionMutation) Type() string { + return m.typ } -// SubscriptionDays returns the value of the "subscription_days" field in the mutation. -func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { - v := m.subscription_days - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdentityAdoptionDecisionMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, identityadoptiondecision.FieldCreatedAt) } - return *v, true -} - -// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + if m.updated_at != nil { + fields = append(fields, identityadoptiondecision.FieldUpdatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + if m.pending_auth_session != nil { + fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) + if m.identity != nil { + fields = append(fields, identityadoptiondecision.FieldIdentityID) } - return oldValue.SubscriptionDays, nil -} - -// AddSubscriptionDays adds i to the "subscription_days" field. -func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { - if m.addsubscription_days != nil { - *m.addsubscription_days += i - } else { - m.addsubscription_days = &i + if m.adopt_display_name != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName) + } + if m.adopt_avatar != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptAvatar) } + if m.decided_at != nil { + fields = append(fields, identityadoptiondecision.FieldDecidedAt) + } + return fields } -// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. -func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { - v := m.addsubscription_days - if v == nil { - return +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.CreatedAt() + case identityadoptiondecision.FieldUpdatedAt: + return m.UpdatedAt() + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.PendingAuthSessionID() + case identityadoptiondecision.FieldIdentityID: + return m.IdentityID() + case identityadoptiondecision.FieldAdoptDisplayName: + return m.AdoptDisplayName() + case identityadoptiondecision.FieldAdoptAvatar: + return m.AdoptAvatar() + case identityadoptiondecision.FieldDecidedAt: + return m.DecidedAt() } - return *v, true + return nil, false } -// ClearSubscriptionDays clears the value of the "subscription_days" field. -func (m *PaymentOrderMutation) ClearSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case identityadoptiondecision.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.OldPendingAuthSessionID(ctx) + case identityadoptiondecision.FieldIdentityID: + return m.OldIdentityID(ctx) + case identityadoptiondecision.FieldAdoptDisplayName: + return m.OldAdoptDisplayName(ctx) + case identityadoptiondecision.FieldAdoptAvatar: + return m.OldAdoptAvatar(ctx) + case identityadoptiondecision.FieldDecidedAt: + return m.OldDecidedAt(ctx) + } + return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] - return ok -} - -// ResetSubscriptionDays resets all changes to the "subscription_days" field. -func (m *PaymentOrderMutation) ResetSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case identityadoptiondecision.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPendingAuthSessionID(v) + return nil + case identityadoptiondecision.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptDisplayName(v) + return nil + case identityadoptiondecision.FieldAdoptAvatar: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptAvatar(v) + return nil + case identityadoptiondecision.FieldDecidedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDecidedAt(v) + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// SetProviderInstanceID sets the "provider_instance_id" field. -func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { - m.provider_instance_id = &s +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedFields() []string { + var fields []string + return fields } -// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. -func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { - v := m.provider_instance_id - if v == nil { - return +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) { + switch name { } - return *v, true + return nil, false } -// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error { + switch name { } - return oldValue.ProviderInstanceID, nil + return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name) } -// ClearProviderInstanceID clears the value of the "provider_instance_id" field. -func (m *PaymentOrderMutation) ClearProviderInstanceID() { - m.provider_instance_id = nil - m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(identityadoptiondecision.FieldIdentityID) { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + return fields } -// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] return ok } -// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. -func (m *PaymentOrderMutation) ResetProviderInstanceID() { - m.provider_instance_id = nil - delete(m.clearedFields, paymentorder.FieldProviderInstanceID) -} - -// SetStatus sets the "status" field. -func (m *PaymentOrderMutation) SetStatus(s string) { - m.status = &s +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error { + switch name { + case identityadoptiondecision.FieldIdentityID: + m.ClearIdentityID() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name) } -// Status returns the value of the "status" field in the mutation. -func (m *PaymentOrderMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case identityadoptiondecision.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + m.ResetPendingAuthSessionID() + return nil + case identityadoptiondecision.FieldIdentityID: + m.ResetIdentityID() + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + m.ResetAdoptDisplayName() + return nil + case identityadoptiondecision.FieldAdoptAvatar: + m.ResetAdoptAvatar() + return nil + case identityadoptiondecision.FieldDecidedAt: + m.ResetDecidedAt() + return nil } - return *v, true + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// OldStatus returns the old "status" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.pending_auth_session != nil { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + if m.identity != nil { + edges = append(edges, identityadoptiondecision.EdgeIdentity) } - return oldValue.Status, nil + return edges } -// ResetStatus resets all changes to the "status" field. -func (m *PaymentOrderMutation) ResetStatus() { - m.status = nil +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + if id := m.pending_auth_session; id != nil { + return []ent.Value{*id} + } + case identityadoptiondecision.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil } -// SetRefundAmount sets the "refund_amount" field. -func (m *PaymentOrderMutation) SetRefundAmount(f float64) { - m.refund_amount = &f - m.addrefund_amount = nil +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges } -// RefundAmount returns the value of the "refund_amount" field in the mutation. -func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { - v := m.refund_amount - if v == nil { - return - } - return *v, true +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value { + return nil } -// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAmount requires an ID field in the mutation") +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedpending_auth_session { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + if m.clearedidentity { + edges = append(edges, identityadoptiondecision.EdgeIdentity) } - return oldValue.RefundAmount, nil + return edges } -// AddRefundAmount adds f to the "refund_amount" field. -func (m *PaymentOrderMutation) AddRefundAmount(f float64) { - if m.addrefund_amount != nil { - *m.addrefund_amount += f - } else { - m.addrefund_amount = &f +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + return m.clearedpending_auth_session + case identityadoptiondecision.EdgeIdentity: + return m.clearedidentity } + return false } -// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { - v := m.addrefund_amount - if v == nil { - return +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ClearPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ClearIdentity() + return nil } - return *v, true + return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name) } -// ResetRefundAmount resets all changes to the "refund_amount" field. -func (m *PaymentOrderMutation) ResetRefundAmount() { - m.refund_amount = nil - m.addrefund_amount = nil +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ResetPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name) } -// SetRefundReason sets the "refund_reason" field. -func (m *PaymentOrderMutation) SetRefundReason(s string) { - m.refund_reason = &s +// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. +type PaymentAuditLogMutation struct { + config + op Op + typ string + id *int64 + order_id *string + action *string + detail *string + operator *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentAuditLog, error) + predicates []predicate.PaymentAuditLog } -// RefundReason returns the value of the "refund_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { - v := m.refund_reason - if v == nil { - return - } - return *v, true -} +var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) -// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") +// paymentauditlogOption allows management of the mutation configuration using functional options. +type paymentauditlogOption func(*PaymentAuditLogMutation) + +// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. +func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { + m := &PaymentAuditLogMutation{ + config: c, + op: op, + typ: TypePaymentAuditLog, + clearedFields: make(map[string]struct{}), } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundReason requires an ID field in the mutation") + for _, opt := range opts { + opt(m) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + return m +} + +// withPaymentAuditLogID sets the ID field of the mutation. +func withPaymentAuditLogID(id int64) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + var ( + err error + once sync.Once + value *PaymentAuditLog + ) + m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentAuditLog.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - return oldValue.RefundReason, nil } -// ClearRefundReason clears the value of the "refund_reason" field. -func (m *PaymentOrderMutation) ClearRefundReason() { - m.refund_reason = nil - m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. +func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + m.oldValue = func(context.Context) (*PaymentAuditLog, error) { + return node, nil + } + m.id = &node.ID + } } -// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundReason] - return ok +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentAuditLogMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// ResetRefundReason resets all changes to the "refund_reason" field. -func (m *PaymentOrderMutation) ResetRefundReason() { - m.refund_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundReason) +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentAuditLogMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// SetRefundAt sets the "refund_at" field. -func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { - m.refund_at = &t +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// RefundAt returns the value of the "refund_at" field in the mutation. -func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { - v := m.refund_at +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetOrderID sets the "order_id" field. +func (m *PaymentAuditLogMutation) SetOrderID(s string) { + m.order_id = &s +} + +// OrderID returns the value of the "order_id" field in the mutation. +func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { + v := m.order_id if v == nil { return } return *v, true } -// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { +func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + return v, errors.New("OldOrderID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAt requires an ID field in the mutation") + return v, errors.New("OldOrderID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + return v, fmt.Errorf("querying old value for OldOrderID: %w", err) } - return oldValue.RefundAt, nil -} - -// ClearRefundAt clears the value of the "refund_at" field. -func (m *PaymentOrderMutation) ClearRefundAt() { - m.refund_at = nil - m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} -} - -// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundAt] - return ok + return oldValue.OrderID, nil } -// ResetRefundAt resets all changes to the "refund_at" field. -func (m *PaymentOrderMutation) ResetRefundAt() { - m.refund_at = nil - delete(m.clearedFields, paymentorder.FieldRefundAt) +// ResetOrderID resets all changes to the "order_id" field. +func (m *PaymentAuditLogMutation) ResetOrderID() { + m.order_id = nil } -// SetForceRefund sets the "force_refund" field. -func (m *PaymentOrderMutation) SetForceRefund(b bool) { - m.force_refund = &b +// SetAction sets the "action" field. +func (m *PaymentAuditLogMutation) SetAction(s string) { + m.action = &s } -// ForceRefund returns the value of the "force_refund" field in the mutation. -func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { - v := m.force_refund +// Action returns the value of the "action" field in the mutation. +func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { + v := m.action if v == nil { return } return *v, true } -// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAction returns the old "action" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { +func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + return v, errors.New("OldAction is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldForceRefund requires an ID field in the mutation") + return v, errors.New("OldAction requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + return v, fmt.Errorf("querying old value for OldAction: %w", err) } - return oldValue.ForceRefund, nil + return oldValue.Action, nil } -// ResetForceRefund resets all changes to the "force_refund" field. -func (m *PaymentOrderMutation) ResetForceRefund() { - m.force_refund = nil +// ResetAction resets all changes to the "action" field. +func (m *PaymentAuditLogMutation) ResetAction() { + m.action = nil } -// SetRefundRequestedAt sets the "refund_requested_at" field. -func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { - m.refund_requested_at = &t +// SetDetail sets the "detail" field. +func (m *PaymentAuditLogMutation) SetDetail(s string) { + m.detail = &s } -// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { - v := m.refund_requested_at +// Detail returns the value of the "detail" field in the mutation. +func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { + v := m.detail if v == nil { return } return *v, true } -// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { +func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + return v, errors.New("OldDetail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + return v, errors.New("OldDetail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + return v, fmt.Errorf("querying old value for OldDetail: %w", err) } - return oldValue.RefundRequestedAt, nil -} - -// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. -func (m *PaymentOrderMutation) ClearRefundRequestedAt() { - m.refund_requested_at = nil - m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} -} - -// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] - return ok + return oldValue.Detail, nil } -// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. -func (m *PaymentOrderMutation) ResetRefundRequestedAt() { - m.refund_requested_at = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +// ResetDetail resets all changes to the "detail" field. +func (m *PaymentAuditLogMutation) ResetDetail() { + m.detail = nil } -// SetRefundRequestReason sets the "refund_request_reason" field. -func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { - m.refund_request_reason = &s +// SetOperator sets the "operator" field. +func (m *PaymentAuditLogMutation) SetOperator(s string) { + m.operator = &s } -// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { - v := m.refund_request_reason +// Operator returns the value of the "operator" field in the mutation. +func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { + v := m.operator if v == nil { return } return *v, true } -// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { +func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + return v, errors.New("OldOperator is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + return v, errors.New("OldOperator requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + return v, fmt.Errorf("querying old value for OldOperator: %w", err) } - return oldValue.RefundRequestReason, nil + return oldValue.Operator, nil } -// ClearRefundRequestReason clears the value of the "refund_request_reason" field. -func (m *PaymentOrderMutation) ClearRefundRequestReason() { - m.refund_request_reason = nil - m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} -} - -// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] - return ok -} - -// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. -func (m *PaymentOrderMutation) ResetRefundRequestReason() { - m.refund_request_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +// ResetOperator resets all changes to the "operator" field. +func (m *PaymentAuditLogMutation) ResetOperator() { + m.operator = nil } -// SetRefundRequestedBy sets the "refund_requested_by" field. -func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { - m.refund_requested_by = &s +// SetCreatedAt sets the "created_at" field. +func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { - v := m.refund_requested_by +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { +func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.RefundRequestedBy, nil + return oldValue.CreatedAt, nil } -// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. -func (m *PaymentOrderMutation) ClearRefundRequestedBy() { - m.refund_requested_by = nil - m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentAuditLogMutation) ResetCreatedAt() { + m.created_at = nil } -// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] - return ok +// Where appends a list predicates to the PaymentAuditLogMutation builder. +func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { + m.predicates = append(m.predicates, ps...) } -// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. -func (m *PaymentOrderMutation) ResetRefundRequestedBy() { - m.refund_requested_by = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentAuditLog, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SetExpiresAt sets the "expires_at" field. -func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// Op returns the operation name. +func (m *PaymentAuditLogMutation) Op() Op { + return m.op } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at - if v == nil { - return - } - return *v, true +// SetOp allows setting the mutation operation. +func (m *PaymentAuditLogMutation) SetOp(op Op) { + m.op = op } -// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") +// Type returns the node type of this mutation (PaymentAuditLog). +func (m *PaymentAuditLogMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentAuditLogMutation) Fields() []string { + fields := make([]string, 0, 5) + if m.order_id != nil { + fields = append(fields, paymentauditlog.FieldOrderID) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + if m.action != nil { + fields = append(fields, paymentauditlog.FieldAction) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + if m.detail != nil { + fields = append(fields, paymentauditlog.FieldDetail) } - return oldValue.ExpiresAt, nil -} - -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *PaymentOrderMutation) ResetExpiresAt() { - m.expires_at = nil + if m.operator != nil { + fields = append(fields, paymentauditlog.FieldOperator) + } + if m.created_at != nil { + fields = append(fields, paymentauditlog.FieldCreatedAt) + } + return fields } -// SetPaidAt sets the "paid_at" field. -func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { - m.paid_at = &t +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OrderID() + case paymentauditlog.FieldAction: + return m.Action() + case paymentauditlog.FieldDetail: + return m.Detail() + case paymentauditlog.FieldOperator: + return m.Operator() + case paymentauditlog.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false } -// PaidAt returns the value of the "paid_at" field in the mutation. -func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { - v := m.paid_at - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OldOrderID(ctx) + case paymentauditlog.FieldAction: + return m.OldAction(ctx) + case paymentauditlog.FieldDetail: + return m.OldDetail(ctx) + case paymentauditlog.FieldOperator: + return m.OldOperator(ctx) + case paymentauditlog.FieldCreatedAt: + return m.OldCreatedAt(ctx) } - return *v, true + return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaidAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentauditlog.FieldOrderID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderID(v) + return nil + case paymentauditlog.FieldAction: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAction(v) + return nil + case paymentauditlog.FieldDetail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDetail(v) + return nil + case paymentauditlog.FieldOperator: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOperator(v) + return nil + case paymentauditlog.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil } - return oldValue.PaidAt, nil + return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// ClearPaidAt clears the value of the "paid_at" field. -func (m *PaymentOrderMutation) ClearPaidAt() { - m.paid_at = nil - m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentAuditLogMutation) AddedFields() []string { + return nil } -// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) PaidAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPaidAt] - return ok +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { + return nil, false } -// ResetPaidAt resets all changes to the "paid_at" field. -func (m *PaymentOrderMutation) ResetPaidAt() { - m.paid_at = nil - delete(m.clearedFields, paymentorder.FieldPaidAt) +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) } -// SetCompletedAt sets the "completed_at" field. -func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { - m.completed_at = &t +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentAuditLogMutation) ClearedFields() []string { + return nil } -// CompletedAt returns the value of the "completed_at" field in the mutation. -func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { - v := m.completed_at - if v == nil { - return - } - return *v, true -} - -// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCompletedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) - } - return oldValue.CompletedAt, nil -} - -// ClearCompletedAt clears the value of the "completed_at" field. -func (m *PaymentOrderMutation) ClearCompletedAt() { - m.completed_at = nil - m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} -} - -// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) CompletedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldCompletedAt] +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] return ok } -// ResetCompletedAt resets all changes to the "completed_at" field. -func (m *PaymentOrderMutation) ResetCompletedAt() { - m.completed_at = nil - delete(m.clearedFields, paymentorder.FieldCompletedAt) -} - -// SetFailedAt sets the "failed_at" field. -func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { - m.failed_at = &t +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentAuditLogMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) } -// FailedAt returns the value of the "failed_at" field in the mutation. -func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { - v := m.failed_at - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentAuditLogMutation) ResetField(name string) error { + switch name { + case paymentauditlog.FieldOrderID: + m.ResetOrderID() + return nil + case paymentauditlog.FieldAction: + m.ResetAction() + return nil + case paymentauditlog.FieldDetail: + m.ResetDetail() + return nil + case paymentauditlog.FieldOperator: + m.ResetOperator() + return nil + case paymentauditlog.FieldCreatedAt: + m.ResetCreatedAt() + return nil } - return *v, true + return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) - } - return oldValue.FailedAt, nil +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentAuditLogMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ClearFailedAt clears the value of the "failed_at" field. -func (m *PaymentOrderMutation) ClearFailedAt() { - m.failed_at = nil - m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { + return nil } -// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedAt] - return ok +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentAuditLogMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ResetFailedAt resets all changes to the "failed_at" field. -func (m *PaymentOrderMutation) ResetFailedAt() { - m.failed_at = nil - delete(m.clearedFields, paymentorder.FieldFailedAt) +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { + return nil } -// SetFailedReason sets the "failed_reason" field. -func (m *PaymentOrderMutation) SetFailedReason(s string) { - m.failed_reason = &s +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentAuditLogMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// FailedReason returns the value of the "failed_reason" field in the mutation. -func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { - v := m.failed_reason - if v == nil { - return - } - return *v, true +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { + return false } -// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedReason requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) - } - return oldValue.FailedReason, nil +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentAuditLogMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) } -// ClearFailedReason clears the value of the "failed_reason" field. -func (m *PaymentOrderMutation) ClearFailedReason() { - m.failed_reason = nil - m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentAuditLogMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog edge %s", name) } -// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedReason] - return ok +// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. +type PaymentOrderMutation struct { + config + op Op + typ string + id *int64 + user_email *string + user_name *string + user_notes *string + amount *float64 + addamount *float64 + pay_amount *float64 + addpay_amount *float64 + fee_rate *float64 + addfee_rate *float64 + recharge_code *string + out_trade_no *string + payment_type *string + payment_trade_no *string + pay_url *string + qr_code *string + qr_code_img *string + order_type *string + plan_id *int64 + addplan_id *int64 + subscription_group_id *int64 + addsubscription_group_id *int64 + subscription_days *int + addsubscription_days *int + provider_instance_id *string + provider_key *string + provider_snapshot *map[string]interface{} + status *string + refund_amount *float64 + addrefund_amount *float64 + refund_reason *string + refund_at *time.Time + force_refund *bool + refund_requested_at *time.Time + refund_request_reason *string + refund_requested_by *string + expires_at *time.Time + paid_at *time.Time + completed_at *time.Time + failed_at *time.Time + failed_reason *string + client_ip *string + src_host *string + src_url *string + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + done bool + oldValue func(context.Context) (*PaymentOrder, error) + predicates []predicate.PaymentOrder } -// ResetFailedReason resets all changes to the "failed_reason" field. -func (m *PaymentOrderMutation) ResetFailedReason() { - m.failed_reason = nil - delete(m.clearedFields, paymentorder.FieldFailedReason) -} +var _ ent.Mutation = (*PaymentOrderMutation)(nil) -// SetClientIP sets the "client_ip" field. -func (m *PaymentOrderMutation) SetClientIP(s string) { - m.client_ip = &s -} +// paymentorderOption allows management of the mutation configuration using functional options. +type paymentorderOption func(*PaymentOrderMutation) -// ClientIP returns the value of the "client_ip" field in the mutation. -func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { - v := m.client_ip - if v == nil { - return +// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. +func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { + m := &PaymentOrderMutation{ + config: c, + op: op, + typ: TypePaymentOrder, + clearedFields: make(map[string]struct{}), } - return *v, true + for _, opt := range opts { + opt(m) + } + return m } -// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClientIP is only allowed on UpdateOne operations") +// withPaymentOrderID sets the ID field of the mutation. +func withPaymentOrderID(id int64) paymentorderOption { + return func(m *PaymentOrderMutation) { + var ( + err error + once sync.Once + value *PaymentOrder + ) + m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentOrder.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClientIP requires an ID field in the mutation") +} + +// withPaymentOrder sets the old PaymentOrder of the mutation. +func withPaymentOrder(node *PaymentOrder) paymentorderOption { + return func(m *PaymentOrderMutation) { + m.oldValue = func(context.Context) (*PaymentOrder, error) { + return node, nil + } + m.id = &node.ID } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldClientIP: %w", err) +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentOrderMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentOrderMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") } - return oldValue.ClientIP, nil + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// ResetClientIP resets all changes to the "client_ip" field. -func (m *PaymentOrderMutation) ResetClientIP() { - m.client_ip = nil +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentOrderMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// SetSrcHost sets the "src_host" field. -func (m *PaymentOrderMutation) SetSrcHost(s string) { - m.src_host = &s +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } } -// SrcHost returns the value of the "src_host" field in the mutation. -func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { - v := m.src_host +// SetUserID sets the "user_id" field. +func (m *PaymentOrderMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. +// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { +func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcHost requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.SrcHost, nil + return oldValue.UserID, nil } -// ResetSrcHost resets all changes to the "src_host" field. -func (m *PaymentOrderMutation) ResetSrcHost() { - m.src_host = nil +// ResetUserID resets all changes to the "user_id" field. +func (m *PaymentOrderMutation) ResetUserID() { + m.user = nil } -// SetSrcURL sets the "src_url" field. -func (m *PaymentOrderMutation) SetSrcURL(s string) { - m.src_url = &s +// SetUserEmail sets the "user_email" field. +func (m *PaymentOrderMutation) SetUserEmail(s string) { + m.user_email = &s } -// SrcURL returns the value of the "src_url" field in the mutation. -func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { - v := m.src_url +// UserEmail returns the value of the "user_email" field in the mutation. +func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { + v := m.user_email if v == nil { return } return *v, true } -// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. +// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { +func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcURL requires an ID field in the mutation") + return v, errors.New("OldUserEmail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) } - return oldValue.SrcURL, nil -} - -// ClearSrcURL clears the value of the "src_url" field. -func (m *PaymentOrderMutation) ClearSrcURL() { - m.src_url = nil - m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} -} - -// SrcURLCleared returns if the "src_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) SrcURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSrcURL] - return ok + return oldValue.UserEmail, nil } -// ResetSrcURL resets all changes to the "src_url" field. -func (m *PaymentOrderMutation) ResetSrcURL() { - m.src_url = nil - delete(m.clearedFields, paymentorder.FieldSrcURL) +// ResetUserEmail resets all changes to the "user_email" field. +func (m *PaymentOrderMutation) ResetUserEmail() { + m.user_email = nil } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetUserName sets the "user_name" field. +func (m *PaymentOrderMutation) SetUserName(s string) { + m.user_name = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// UserName returns the value of the "user_name" field in the mutation. +func (m *PaymentOrderMutation) UserName() (r string, exists bool) { + v := m.user_name if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. +// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldUserName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldUserName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldUserName: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.UserName, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentOrderMutation) ResetCreatedAt() { - m.created_at = nil +// ResetUserName resets all changes to the "user_name" field. +func (m *PaymentOrderMutation) ResetUserName() { + m.user_name = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetUserNotes sets the "user_notes" field. +func (m *PaymentOrderMutation) SetUserNotes(s string) { + m.user_notes = &s } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// UserNotes returns the value of the "user_notes" field in the mutation. +func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { + v := m.user_notes if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. +// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldUserNotes requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.UserNotes, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentOrderMutation) ResetUpdatedAt() { - m.updated_at = nil +// ClearUserNotes clears the value of the "user_notes" field. +func (m *PaymentOrderMutation) ClearUserNotes() { + m.user_notes = nil + m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} } -// ClearUser clears the "user" edge to the User entity. -func (m *PaymentOrderMutation) ClearUser() { - m.cleareduser = true - m.clearedFields[paymentorder.FieldUserID] = struct{}{} +// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. +func (m *PaymentOrderMutation) UserNotesCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldUserNotes] + return ok } -// UserCleared reports if the "user" edge to the User entity was cleared. -func (m *PaymentOrderMutation) UserCleared() bool { - return m.cleareduser +// ResetUserNotes resets all changes to the "user_notes" field. +func (m *PaymentOrderMutation) ResetUserNotes() { + m.user_notes = nil + delete(m.clearedFields, paymentorder.FieldUserNotes) } -// UserIDs returns the "user" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// UserID instead. It exists only for internal usage by the builders. -func (m *PaymentOrderMutation) UserIDs() (ids []int64) { - if id := m.user; id != nil { - ids = append(ids, *id) +// SetAmount sets the "amount" field. +func (m *PaymentOrderMutation) SetAmount(f float64) { + m.amount = &f + m.addamount = nil +} + +// Amount returns the value of the "amount" field in the mutation. +func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { + v := m.amount + if v == nil { + return } - return + return *v, true } -// ResetUser resets all changes to the "user" edge. -func (m *PaymentOrderMutation) ResetUser() { - m.user = nil - m.cleareduser = false +// OldAmount returns the old "amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAmount: %w", err) + } + return oldValue.Amount, nil } -// Where appends a list predicates to the PaymentOrderMutation builder. -func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { - m.predicates = append(m.predicates, ps...) +// AddAmount adds f to the "amount" field. +func (m *PaymentOrderMutation) AddAmount(f float64) { + if m.addamount != nil { + *m.addamount += f + } else { + m.addamount = &f + } } -// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentOrder, len(ps)) - for i := range ps { - p[i] = ps[i] +// AddedAmount returns the value that was added to the "amount" field in this mutation. +func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { + v := m.addamount + if v == nil { + return } - m.Where(p...) + return *v, true } -// Op returns the operation name. -func (m *PaymentOrderMutation) Op() Op { - return m.op +// ResetAmount resets all changes to the "amount" field. +func (m *PaymentOrderMutation) ResetAmount() { + m.amount = nil + m.addamount = nil } -// SetOp allows setting the mutation operation. -func (m *PaymentOrderMutation) SetOp(op Op) { - m.op = op +// SetPayAmount sets the "pay_amount" field. +func (m *PaymentOrderMutation) SetPayAmount(f float64) { + m.pay_amount = &f + m.addpay_amount = nil } -// Type returns the node type of this mutation (PaymentOrder). -func (m *PaymentOrderMutation) Type() string { - return m.typ +// PayAmount returns the value of the "pay_amount" field in the mutation. +func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { + v := m.pay_amount + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *PaymentOrderMutation) Fields() []string { - fields := make([]string, 0, 37) - if m.user != nil { - fields = append(fields, paymentorder.FieldUserID) - } - if m.user_email != nil { - fields = append(fields, paymentorder.FieldUserEmail) - } - if m.user_name != nil { - fields = append(fields, paymentorder.FieldUserName) +// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") } - if m.user_notes != nil { - fields = append(fields, paymentorder.FieldUserNotes) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayAmount requires an ID field in the mutation") } - if m.amount != nil { - fields = append(fields, paymentorder.FieldAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) } - if m.pay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) + return oldValue.PayAmount, nil +} + +// AddPayAmount adds f to the "pay_amount" field. +func (m *PaymentOrderMutation) AddPayAmount(f float64) { + if m.addpay_amount != nil { + *m.addpay_amount += f + } else { + m.addpay_amount = &f } - if m.fee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) +} + +// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { + v := m.addpay_amount + if v == nil { + return } - if m.recharge_code != nil { - fields = append(fields, paymentorder.FieldRechargeCode) + return *v, true +} + +// ResetPayAmount resets all changes to the "pay_amount" field. +func (m *PaymentOrderMutation) ResetPayAmount() { + m.pay_amount = nil + m.addpay_amount = nil +} + +// SetFeeRate sets the "fee_rate" field. +func (m *PaymentOrderMutation) SetFeeRate(f float64) { + m.fee_rate = &f + m.addfee_rate = nil +} + +// FeeRate returns the value of the "fee_rate" field in the mutation. +func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { + v := m.fee_rate + if v == nil { + return } - if m.out_trade_no != nil { - fields = append(fields, paymentorder.FieldOutTradeNo) + return *v, true +} + +// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") } - if m.payment_type != nil { - fields = append(fields, paymentorder.FieldPaymentType) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeeRate requires an ID field in the mutation") } - if m.payment_trade_no != nil { - fields = append(fields, paymentorder.FieldPaymentTradeNo) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) } - if m.pay_url != nil { - fields = append(fields, paymentorder.FieldPayURL) + return oldValue.FeeRate, nil +} + +// AddFeeRate adds f to the "fee_rate" field. +func (m *PaymentOrderMutation) AddFeeRate(f float64) { + if m.addfee_rate != nil { + *m.addfee_rate += f + } else { + m.addfee_rate = &f } - if m.qr_code != nil { - fields = append(fields, paymentorder.FieldQrCode) +} + +// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. +func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { + v := m.addfee_rate + if v == nil { + return } - if m.qr_code_img != nil { - fields = append(fields, paymentorder.FieldQrCodeImg) + return *v, true +} + +// ResetFeeRate resets all changes to the "fee_rate" field. +func (m *PaymentOrderMutation) ResetFeeRate() { + m.fee_rate = nil + m.addfee_rate = nil +} + +// SetRechargeCode sets the "recharge_code" field. +func (m *PaymentOrderMutation) SetRechargeCode(s string) { + m.recharge_code = &s +} + +// RechargeCode returns the value of the "recharge_code" field in the mutation. +func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { + v := m.recharge_code + if v == nil { + return } - if m.order_type != nil { - fields = append(fields, paymentorder.FieldOrderType) + return *v, true +} + +// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") } - if m.plan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRechargeCode requires an ID field in the mutation") } - if m.subscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) } - if m.subscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) + return oldValue.RechargeCode, nil +} + +// ResetRechargeCode resets all changes to the "recharge_code" field. +func (m *PaymentOrderMutation) ResetRechargeCode() { + m.recharge_code = nil +} + +// SetOutTradeNo sets the "out_trade_no" field. +func (m *PaymentOrderMutation) SetOutTradeNo(s string) { + m.out_trade_no = &s +} + +// OutTradeNo returns the value of the "out_trade_no" field in the mutation. +func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { + v := m.out_trade_no + if v == nil { + return } - if m.provider_instance_id != nil { - fields = append(fields, paymentorder.FieldProviderInstanceID) + return *v, true +} + +// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") } - if m.status != nil { - fields = append(fields, paymentorder.FieldStatus) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutTradeNo requires an ID field in the mutation") } - if m.refund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) } - if m.refund_reason != nil { - fields = append(fields, paymentorder.FieldRefundReason) - } - if m.refund_at != nil { - fields = append(fields, paymentorder.FieldRefundAt) - } - if m.force_refund != nil { - fields = append(fields, paymentorder.FieldForceRefund) - } - if m.refund_requested_at != nil { - fields = append(fields, paymentorder.FieldRefundRequestedAt) + return oldValue.OutTradeNo, nil +} + +// ResetOutTradeNo resets all changes to the "out_trade_no" field. +func (m *PaymentOrderMutation) ResetOutTradeNo() { + m.out_trade_no = nil +} + +// SetPaymentType sets the "payment_type" field. +func (m *PaymentOrderMutation) SetPaymentType(s string) { + m.payment_type = &s +} + +// PaymentType returns the value of the "payment_type" field in the mutation. +func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { + v := m.payment_type + if v == nil { + return } - if m.refund_request_reason != nil { - fields = append(fields, paymentorder.FieldRefundRequestReason) + return *v, true +} + +// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") } - if m.refund_requested_by != nil { - fields = append(fields, paymentorder.FieldRefundRequestedBy) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentType requires an ID field in the mutation") } - if m.expires_at != nil { - fields = append(fields, paymentorder.FieldExpiresAt) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) } - if m.paid_at != nil { - fields = append(fields, paymentorder.FieldPaidAt) + return oldValue.PaymentType, nil +} + +// ResetPaymentType resets all changes to the "payment_type" field. +func (m *PaymentOrderMutation) ResetPaymentType() { + m.payment_type = nil +} + +// SetPaymentTradeNo sets the "payment_trade_no" field. +func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { + m.payment_trade_no = &s +} + +// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. +func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { + v := m.payment_trade_no + if v == nil { + return } - if m.completed_at != nil { - fields = append(fields, paymentorder.FieldCompletedAt) + return *v, true +} + +// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") } - if m.failed_at != nil { - fields = append(fields, paymentorder.FieldFailedAt) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") } - if m.failed_reason != nil { - fields = append(fields, paymentorder.FieldFailedReason) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) } - if m.client_ip != nil { - fields = append(fields, paymentorder.FieldClientIP) + return oldValue.PaymentTradeNo, nil +} + +// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. +func (m *PaymentOrderMutation) ResetPaymentTradeNo() { + m.payment_trade_no = nil +} + +// SetPayURL sets the "pay_url" field. +func (m *PaymentOrderMutation) SetPayURL(s string) { + m.pay_url = &s +} + +// PayURL returns the value of the "pay_url" field in the mutation. +func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { + v := m.pay_url + if v == nil { + return } - if m.src_host != nil { - fields = append(fields, paymentorder.FieldSrcHost) + return *v, true +} + +// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayURL is only allowed on UpdateOne operations") } - if m.src_url != nil { - fields = append(fields, paymentorder.FieldSrcURL) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayURL requires an ID field in the mutation") } - if m.created_at != nil { - fields = append(fields, paymentorder.FieldCreatedAt) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayURL: %w", err) } - if m.updated_at != nil { - fields = append(fields, paymentorder.FieldUpdatedAt) + return oldValue.PayURL, nil +} + +// ClearPayURL clears the value of the "pay_url" field. +func (m *PaymentOrderMutation) ClearPayURL() { + m.pay_url = nil + m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +} + +// PayURLCleared returns if the "pay_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) PayURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPayURL] + return ok +} + +// ResetPayURL resets all changes to the "pay_url" field. +func (m *PaymentOrderMutation) ResetPayURL() { + m.pay_url = nil + delete(m.clearedFields, paymentorder.FieldPayURL) +} + +// SetQrCode sets the "qr_code" field. +func (m *PaymentOrderMutation) SetQrCode(s string) { + m.qr_code = &s +} + +// QrCode returns the value of the "qr_code" field in the mutation. +func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { + v := m.qr_code + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { - switch name { - case paymentorder.FieldUserID: - return m.UserID() - case paymentorder.FieldUserEmail: - return m.UserEmail() - case paymentorder.FieldUserName: - return m.UserName() - case paymentorder.FieldUserNotes: - return m.UserNotes() - case paymentorder.FieldAmount: - return m.Amount() - case paymentorder.FieldPayAmount: - return m.PayAmount() - case paymentorder.FieldFeeRate: - return m.FeeRate() - case paymentorder.FieldRechargeCode: - return m.RechargeCode() - case paymentorder.FieldOutTradeNo: - return m.OutTradeNo() - case paymentorder.FieldPaymentType: - return m.PaymentType() - case paymentorder.FieldPaymentTradeNo: - return m.PaymentTradeNo() - case paymentorder.FieldPayURL: - return m.PayURL() - case paymentorder.FieldQrCode: - return m.QrCode() - case paymentorder.FieldQrCodeImg: - return m.QrCodeImg() - case paymentorder.FieldOrderType: - return m.OrderType() - case paymentorder.FieldPlanID: - return m.PlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.SubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.SubscriptionDays() - case paymentorder.FieldProviderInstanceID: - return m.ProviderInstanceID() - case paymentorder.FieldStatus: - return m.Status() - case paymentorder.FieldRefundAmount: - return m.RefundAmount() - case paymentorder.FieldRefundReason: - return m.RefundReason() - case paymentorder.FieldRefundAt: - return m.RefundAt() - case paymentorder.FieldForceRefund: - return m.ForceRefund() - case paymentorder.FieldRefundRequestedAt: - return m.RefundRequestedAt() - case paymentorder.FieldRefundRequestReason: - return m.RefundRequestReason() - case paymentorder.FieldRefundRequestedBy: - return m.RefundRequestedBy() - case paymentorder.FieldExpiresAt: - return m.ExpiresAt() - case paymentorder.FieldPaidAt: - return m.PaidAt() - case paymentorder.FieldCompletedAt: - return m.CompletedAt() - case paymentorder.FieldFailedAt: - return m.FailedAt() - case paymentorder.FieldFailedReason: - return m.FailedReason() - case paymentorder.FieldClientIP: - return m.ClientIP() - case paymentorder.FieldSrcHost: - return m.SrcHost() - case paymentorder.FieldSrcURL: - return m.SrcURL() - case paymentorder.FieldCreatedAt: - return m.CreatedAt() - case paymentorder.FieldUpdatedAt: - return m.UpdatedAt() +// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCode is only allowed on UpdateOne operations") } - return nil, false + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCode: %w", err) + } + return oldValue.QrCode, nil } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case paymentorder.FieldUserID: - return m.OldUserID(ctx) - case paymentorder.FieldUserEmail: - return m.OldUserEmail(ctx) - case paymentorder.FieldUserName: - return m.OldUserName(ctx) - case paymentorder.FieldUserNotes: - return m.OldUserNotes(ctx) - case paymentorder.FieldAmount: - return m.OldAmount(ctx) - case paymentorder.FieldPayAmount: - return m.OldPayAmount(ctx) - case paymentorder.FieldFeeRate: - return m.OldFeeRate(ctx) - case paymentorder.FieldRechargeCode: - return m.OldRechargeCode(ctx) - case paymentorder.FieldOutTradeNo: - return m.OldOutTradeNo(ctx) - case paymentorder.FieldPaymentType: - return m.OldPaymentType(ctx) - case paymentorder.FieldPaymentTradeNo: - return m.OldPaymentTradeNo(ctx) - case paymentorder.FieldPayURL: - return m.OldPayURL(ctx) - case paymentorder.FieldQrCode: - return m.OldQrCode(ctx) - case paymentorder.FieldQrCodeImg: - return m.OldQrCodeImg(ctx) - case paymentorder.FieldOrderType: - return m.OldOrderType(ctx) - case paymentorder.FieldPlanID: - return m.OldPlanID(ctx) - case paymentorder.FieldSubscriptionGroupID: - return m.OldSubscriptionGroupID(ctx) - case paymentorder.FieldSubscriptionDays: - return m.OldSubscriptionDays(ctx) - case paymentorder.FieldProviderInstanceID: - return m.OldProviderInstanceID(ctx) - case paymentorder.FieldStatus: - return m.OldStatus(ctx) - case paymentorder.FieldRefundAmount: - return m.OldRefundAmount(ctx) - case paymentorder.FieldRefundReason: - return m.OldRefundReason(ctx) - case paymentorder.FieldRefundAt: - return m.OldRefundAt(ctx) - case paymentorder.FieldForceRefund: - return m.OldForceRefund(ctx) - case paymentorder.FieldRefundRequestedAt: - return m.OldRefundRequestedAt(ctx) - case paymentorder.FieldRefundRequestReason: - return m.OldRefundRequestReason(ctx) - case paymentorder.FieldRefundRequestedBy: - return m.OldRefundRequestedBy(ctx) - case paymentorder.FieldExpiresAt: - return m.OldExpiresAt(ctx) - case paymentorder.FieldPaidAt: - return m.OldPaidAt(ctx) - case paymentorder.FieldCompletedAt: - return m.OldCompletedAt(ctx) - case paymentorder.FieldFailedAt: - return m.OldFailedAt(ctx) - case paymentorder.FieldFailedReason: - return m.OldFailedReason(ctx) - case paymentorder.FieldClientIP: - return m.OldClientIP(ctx) - case paymentorder.FieldSrcHost: - return m.OldSrcHost(ctx) - case paymentorder.FieldSrcURL: - return m.OldSrcURL(ctx) - case paymentorder.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case paymentorder.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) +// ClearQrCode clears the value of the "qr_code" field. +func (m *PaymentOrderMutation) ClearQrCode() { + m.qr_code = nil + m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +} + +// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCode] + return ok +} + +// ResetQrCode resets all changes to the "qr_code" field. +func (m *PaymentOrderMutation) ResetQrCode() { + m.qr_code = nil + delete(m.clearedFields, paymentorder.FieldQrCode) +} + +// SetQrCodeImg sets the "qr_code_img" field. +func (m *PaymentOrderMutation) SetQrCodeImg(s string) { + m.qr_code_img = &s +} + +// QrCodeImg returns the value of the "qr_code_img" field in the mutation. +func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { + v := m.qr_code_img + if v == nil { + return } - return nil, fmt.Errorf("unknown PaymentOrder field %s", name) + return *v, true } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { - switch name { - case paymentorder.FieldUserID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserID(v) - return nil - case paymentorder.FieldUserEmail: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserEmail(v) - return nil - case paymentorder.FieldUserName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserName(v) - return nil - case paymentorder.FieldUserNotes: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserNotes(v) - return nil - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFeeRate(v) - return nil - case paymentorder.FieldRechargeCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRechargeCode(v) - return nil - case paymentorder.FieldOutTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOutTradeNo(v) - return nil - case paymentorder.FieldPaymentType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaymentType(v) - return nil - case paymentorder.FieldPaymentTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaymentTradeNo(v) - return nil - case paymentorder.FieldPayURL: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayURL(v) - return nil - case paymentorder.FieldQrCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCode(v) - return nil - case paymentorder.FieldQrCodeImg: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCodeImg(v) - return nil - case paymentorder.FieldOrderType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOrderType(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionDays(v) - return nil - case paymentorder.FieldProviderInstanceID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetProviderInstanceID(v) - return nil - case paymentorder.FieldStatus: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStatus(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) +// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCodeImg requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + } + return oldValue.QrCodeImg, nil +} + +// ClearQrCodeImg clears the value of the "qr_code_img" field. +func (m *PaymentOrderMutation) ClearQrCodeImg() { + m.qr_code_img = nil + m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} +} + +// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeImgCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] + return ok +} + +// ResetQrCodeImg resets all changes to the "qr_code_img" field. +func (m *PaymentOrderMutation) ResetQrCodeImg() { + m.qr_code_img = nil + delete(m.clearedFields, paymentorder.FieldQrCodeImg) +} + +// SetOrderType sets the "order_type" field. +func (m *PaymentOrderMutation) SetOrderType(s string) { + m.order_type = &s +} + +// OrderType returns the value of the "order_type" field in the mutation. +func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { + v := m.order_type + if v == nil { + return + } + return *v, true +} + +// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOrderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOrderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOrderType: %w", err) + } + return oldValue.OrderType, nil +} + +// ResetOrderType resets all changes to the "order_type" field. +func (m *PaymentOrderMutation) ResetOrderType() { + m.order_type = nil +} + +// SetPlanID sets the "plan_id" field. +func (m *PaymentOrderMutation) SetPlanID(i int64) { + m.plan_id = &i + m.addplan_id = nil +} + +// PlanID returns the value of the "plan_id" field in the mutation. +func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { + v := m.plan_id + if v == nil { + return + } + return *v, true +} + +// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlanID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlanID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlanID: %w", err) + } + return oldValue.PlanID, nil +} + +// AddPlanID adds i to the "plan_id" field. +func (m *PaymentOrderMutation) AddPlanID(i int64) { + if m.addplan_id != nil { + *m.addplan_id += i + } else { + m.addplan_id = &i + } +} + +// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. +func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { + v := m.addplan_id + if v == nil { + return + } + return *v, true +} + +// ClearPlanID clears the value of the "plan_id" field. +func (m *PaymentOrderMutation) ClearPlanID() { + m.plan_id = nil + m.addplan_id = nil + m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +} + +// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) PlanIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPlanID] + return ok +} + +// ResetPlanID resets all changes to the "plan_id" field. +func (m *PaymentOrderMutation) ResetPlanID() { + m.plan_id = nil + m.addplan_id = nil + delete(m.clearedFields, paymentorder.FieldPlanID) +} + +// SetSubscriptionGroupID sets the "subscription_group_id" field. +func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { + m.subscription_group_id = &i + m.addsubscription_group_id = nil +} + +// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { + v := m.subscription_group_id + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) + } + return oldValue.SubscriptionGroupID, nil +} + +// AddSubscriptionGroupID adds i to the "subscription_group_id" field. +func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { + if m.addsubscription_group_id != nil { + *m.addsubscription_group_id += i + } else { + m.addsubscription_group_id = &i + } +} + +// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { + v := m.addsubscription_group_id + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. +func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} +} + +// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] + return ok +} + +// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. +func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +} + +// SetSubscriptionDays sets the "subscription_days" field. +func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { + m.subscription_days = &i + m.addsubscription_days = nil +} + +// SubscriptionDays returns the value of the "subscription_days" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { + v := m.subscription_days + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) + } + return oldValue.SubscriptionDays, nil +} + +// AddSubscriptionDays adds i to the "subscription_days" field. +func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { + if m.addsubscription_days != nil { + *m.addsubscription_days += i + } else { + m.addsubscription_days = &i + } +} + +// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { + v := m.addsubscription_days + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionDays clears the value of the "subscription_days" field. +func (m *PaymentOrderMutation) ClearSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} +} + +// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] + return ok +} + +// ResetSubscriptionDays resets all changes to the "subscription_days" field. +func (m *PaymentOrderMutation) ResetSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +} + +// SetProviderInstanceID sets the "provider_instance_id" field. +func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { + m.provider_instance_id = &s +} + +// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. +func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { + v := m.provider_instance_id + if v == nil { + return + } + return *v, true +} + +// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) + } + return oldValue.ProviderInstanceID, nil +} + +// ClearProviderInstanceID clears the value of the "provider_instance_id" field. +func (m *PaymentOrderMutation) ClearProviderInstanceID() { + m.provider_instance_id = nil + m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} +} + +// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] + return ok +} + +// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. +func (m *PaymentOrderMutation) ResetProviderInstanceID() { + m.provider_instance_id = nil + delete(m.clearedFields, paymentorder.FieldProviderInstanceID) +} + +// SetProviderKey sets the "provider_key" field. +func (m *PaymentOrderMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (m *PaymentOrderMutation) ClearProviderKey() { + m.provider_key = nil + m.clearedFields[paymentorder.FieldProviderKey] = struct{}{} +} + +// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderKeyCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderKey] + return ok +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentOrderMutation) ResetProviderKey() { + m.provider_key = nil + delete(m.clearedFields, paymentorder.FieldProviderKey) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) { + m.provider_snapshot = &value +} + +// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation. +func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) { + v := m.provider_snapshot + if v == nil { + return + } + return *v, true +} + +// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSnapshot requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err) + } + return oldValue.ProviderSnapshot, nil +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (m *PaymentOrderMutation) ClearProviderSnapshot() { + m.provider_snapshot = nil + m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{} +} + +// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot] + return ok +} + +// ResetProviderSnapshot resets all changes to the "provider_snapshot" field. +func (m *PaymentOrderMutation) ResetProviderSnapshot() { + m.provider_snapshot = nil + delete(m.clearedFields, paymentorder.FieldProviderSnapshot) +} + +// SetStatus sets the "status" field. +func (m *PaymentOrderMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *PaymentOrderMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *PaymentOrderMutation) ResetStatus() { + m.status = nil +} + +// SetRefundAmount sets the "refund_amount" field. +func (m *PaymentOrderMutation) SetRefundAmount(f float64) { + m.refund_amount = &f + m.addrefund_amount = nil +} + +// RefundAmount returns the value of the "refund_amount" field in the mutation. +func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { + v := m.refund_amount + if v == nil { + return + } + return *v, true +} + +// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + } + return oldValue.RefundAmount, nil +} + +// AddRefundAmount adds f to the "refund_amount" field. +func (m *PaymentOrderMutation) AddRefundAmount(f float64) { + if m.addrefund_amount != nil { + *m.addrefund_amount += f + } else { + m.addrefund_amount = &f + } +} + +// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { + v := m.addrefund_amount + if v == nil { + return + } + return *v, true +} + +// ResetRefundAmount resets all changes to the "refund_amount" field. +func (m *PaymentOrderMutation) ResetRefundAmount() { + m.refund_amount = nil + m.addrefund_amount = nil +} + +// SetRefundReason sets the "refund_reason" field. +func (m *PaymentOrderMutation) SetRefundReason(s string) { + m.refund_reason = &s +} + +// RefundReason returns the value of the "refund_reason" field in the mutation. +func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { + v := m.refund_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + } + return oldValue.RefundReason, nil +} + +// ClearRefundReason clears the value of the "refund_reason" field. +func (m *PaymentOrderMutation) ClearRefundReason() { + m.refund_reason = nil + m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +} + +// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundReason] + return ok +} + +// ResetRefundReason resets all changes to the "refund_reason" field. +func (m *PaymentOrderMutation) ResetRefundReason() { + m.refund_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundReason) +} + +// SetRefundAt sets the "refund_at" field. +func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { + m.refund_at = &t +} + +// RefundAt returns the value of the "refund_at" field in the mutation. +func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { + v := m.refund_at + if v == nil { + return + } + return *v, true +} + +// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + } + return oldValue.RefundAt, nil +} + +// ClearRefundAt clears the value of the "refund_at" field. +func (m *PaymentOrderMutation) ClearRefundAt() { + m.refund_at = nil + m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} +} + +// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundAt] + return ok +} + +// ResetRefundAt resets all changes to the "refund_at" field. +func (m *PaymentOrderMutation) ResetRefundAt() { + m.refund_at = nil + delete(m.clearedFields, paymentorder.FieldRefundAt) +} + +// SetForceRefund sets the "force_refund" field. +func (m *PaymentOrderMutation) SetForceRefund(b bool) { + m.force_refund = &b +} + +// ForceRefund returns the value of the "force_refund" field in the mutation. +func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { + v := m.force_refund + if v == nil { + return + } + return *v, true +} + +// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldForceRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + } + return oldValue.ForceRefund, nil +} + +// ResetForceRefund resets all changes to the "force_refund" field. +func (m *PaymentOrderMutation) ResetForceRefund() { + m.force_refund = nil +} + +// SetRefundRequestedAt sets the "refund_requested_at" field. +func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { + m.refund_requested_at = &t +} + +// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { + v := m.refund_requested_at + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + } + return oldValue.RefundRequestedAt, nil +} + +// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. +func (m *PaymentOrderMutation) ClearRefundRequestedAt() { + m.refund_requested_at = nil + m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} +} + +// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] + return ok +} + +// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. +func (m *PaymentOrderMutation) ResetRefundRequestedAt() { + m.refund_requested_at = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +} + +// SetRefundRequestReason sets the "refund_request_reason" field. +func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { + m.refund_request_reason = &s +} + +// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { + v := m.refund_request_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + } + return oldValue.RefundRequestReason, nil +} + +// ClearRefundRequestReason clears the value of the "refund_request_reason" field. +func (m *PaymentOrderMutation) ClearRefundRequestReason() { + m.refund_request_reason = nil + m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} +} + +// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] + return ok +} + +// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. +func (m *PaymentOrderMutation) ResetRefundRequestReason() { + m.refund_request_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +} + +// SetRefundRequestedBy sets the "refund_requested_by" field. +func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { + m.refund_requested_by = &s +} + +// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { + v := m.refund_requested_by + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + } + return oldValue.RefundRequestedBy, nil +} + +// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. +func (m *PaymentOrderMutation) ClearRefundRequestedBy() { + m.refund_requested_by = nil + m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +} + +// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] + return ok +} + +// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. +func (m *PaymentOrderMutation) ResetRefundRequestedBy() { + m.refund_requested_by = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PaymentOrderMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetPaidAt sets the "paid_at" field. +func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { + m.paid_at = &t +} + +// PaidAt returns the value of the "paid_at" field in the mutation. +func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { + v := m.paid_at + if v == nil { + return + } + return *v, true +} + +// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaidAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) + } + return oldValue.PaidAt, nil +} + +// ClearPaidAt clears the value of the "paid_at" field. +func (m *PaymentOrderMutation) ClearPaidAt() { + m.paid_at = nil + m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +} + +// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) PaidAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPaidAt] + return ok +} + +// ResetPaidAt resets all changes to the "paid_at" field. +func (m *PaymentOrderMutation) ResetPaidAt() { + m.paid_at = nil + delete(m.clearedFields, paymentorder.FieldPaidAt) +} + +// SetCompletedAt sets the "completed_at" field. +func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { + m.completed_at = &t +} + +// CompletedAt returns the value of the "completed_at" field in the mutation. +func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { + v := m.completed_at + if v == nil { + return + } + return *v, true +} + +// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) + } + return oldValue.CompletedAt, nil +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *PaymentOrderMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldCompletedAt] + return ok +} + +// ResetCompletedAt resets all changes to the "completed_at" field. +func (m *PaymentOrderMutation) ResetCompletedAt() { + m.completed_at = nil + delete(m.clearedFields, paymentorder.FieldCompletedAt) +} + +// SetFailedAt sets the "failed_at" field. +func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { + m.failed_at = &t +} + +// FailedAt returns the value of the "failed_at" field in the mutation. +func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { + v := m.failed_at + if v == nil { + return + } + return *v, true +} + +// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) + } + return oldValue.FailedAt, nil +} + +// ClearFailedAt clears the value of the "failed_at" field. +func (m *PaymentOrderMutation) ClearFailedAt() { + m.failed_at = nil + m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +} + +// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedAt] + return ok +} + +// ResetFailedAt resets all changes to the "failed_at" field. +func (m *PaymentOrderMutation) ResetFailedAt() { + m.failed_at = nil + delete(m.clearedFields, paymentorder.FieldFailedAt) +} + +// SetFailedReason sets the "failed_reason" field. +func (m *PaymentOrderMutation) SetFailedReason(s string) { + m.failed_reason = &s +} + +// FailedReason returns the value of the "failed_reason" field in the mutation. +func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { + v := m.failed_reason + if v == nil { + return + } + return *v, true +} + +// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) + } + return oldValue.FailedReason, nil +} + +// ClearFailedReason clears the value of the "failed_reason" field. +func (m *PaymentOrderMutation) ClearFailedReason() { + m.failed_reason = nil + m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} +} + +// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedReason] + return ok +} + +// ResetFailedReason resets all changes to the "failed_reason" field. +func (m *PaymentOrderMutation) ResetFailedReason() { + m.failed_reason = nil + delete(m.clearedFields, paymentorder.FieldFailedReason) +} + +// SetClientIP sets the "client_ip" field. +func (m *PaymentOrderMutation) SetClientIP(s string) { + m.client_ip = &s +} + +// ClientIP returns the value of the "client_ip" field in the mutation. +func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { + v := m.client_ip + if v == nil { + return + } + return *v, true +} + +// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClientIP is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClientIP requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClientIP: %w", err) + } + return oldValue.ClientIP, nil +} + +// ResetClientIP resets all changes to the "client_ip" field. +func (m *PaymentOrderMutation) ResetClientIP() { + m.client_ip = nil +} + +// SetSrcHost sets the "src_host" field. +func (m *PaymentOrderMutation) SetSrcHost(s string) { + m.src_host = &s +} + +// SrcHost returns the value of the "src_host" field in the mutation. +func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { + v := m.src_host + if v == nil { + return + } + return *v, true +} + +// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcHost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + } + return oldValue.SrcHost, nil +} + +// ResetSrcHost resets all changes to the "src_host" field. +func (m *PaymentOrderMutation) ResetSrcHost() { + m.src_host = nil +} + +// SetSrcURL sets the "src_url" field. +func (m *PaymentOrderMutation) SetSrcURL(s string) { + m.src_url = &s +} + +// SrcURL returns the value of the "src_url" field in the mutation. +func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { + v := m.src_url + if v == nil { + return + } + return *v, true +} + +// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + } + return oldValue.SrcURL, nil +} + +// ClearSrcURL clears the value of the "src_url" field. +func (m *PaymentOrderMutation) ClearSrcURL() { + m.src_url = nil + m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} +} + +// SrcURLCleared returns if the "src_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) SrcURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSrcURL] + return ok +} + +// ResetSrcURL resets all changes to the "src_url" field. +func (m *PaymentOrderMutation) ResetSrcURL() { + m.src_url = nil + delete(m.clearedFields, paymentorder.FieldSrcURL) +} + +// SetCreatedAt sets the "created_at" field. +func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentOrderMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PaymentOrderMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *PaymentOrderMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[paymentorder.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *PaymentOrderMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *PaymentOrderMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *PaymentOrderMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the PaymentOrderMutation builder. +func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentOrder, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentOrderMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PaymentOrderMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentOrder). +func (m *PaymentOrderMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentOrderMutation) Fields() []string { + fields := make([]string, 0, 39) + if m.user != nil { + fields = append(fields, paymentorder.FieldUserID) + } + if m.user_email != nil { + fields = append(fields, paymentorder.FieldUserEmail) + } + if m.user_name != nil { + fields = append(fields, paymentorder.FieldUserName) + } + if m.user_notes != nil { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.amount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.pay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.fee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.recharge_code != nil { + fields = append(fields, paymentorder.FieldRechargeCode) + } + if m.out_trade_no != nil { + fields = append(fields, paymentorder.FieldOutTradeNo) + } + if m.payment_type != nil { + fields = append(fields, paymentorder.FieldPaymentType) + } + if m.payment_trade_no != nil { + fields = append(fields, paymentorder.FieldPaymentTradeNo) + } + if m.pay_url != nil { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.qr_code != nil { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.qr_code_img != nil { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.order_type != nil { + fields = append(fields, paymentorder.FieldOrderType) + } + if m.plan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.subscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.subscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.provider_instance_id != nil { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.provider_key != nil { + fields = append(fields, paymentorder.FieldProviderKey) + } + if m.provider_snapshot != nil { + fields = append(fields, paymentorder.FieldProviderSnapshot) + } + if m.status != nil { + fields = append(fields, paymentorder.FieldStatus) + } + if m.refund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + if m.refund_reason != nil { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.refund_at != nil { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.force_refund != nil { + fields = append(fields, paymentorder.FieldForceRefund) + } + if m.refund_requested_at != nil { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.refund_request_reason != nil { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.refund_requested_by != nil { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.expires_at != nil { + fields = append(fields, paymentorder.FieldExpiresAt) + } + if m.paid_at != nil { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.completed_at != nil { + fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.failed_at != nil { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.failed_reason != nil { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.client_ip != nil { + fields = append(fields, paymentorder.FieldClientIP) + } + if m.src_host != nil { + fields = append(fields, paymentorder.FieldSrcHost) + } + if m.src_url != nil { + fields = append(fields, paymentorder.FieldSrcURL) + } + if m.created_at != nil { + fields = append(fields, paymentorder.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, paymentorder.FieldUpdatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldUserID: + return m.UserID() + case paymentorder.FieldUserEmail: + return m.UserEmail() + case paymentorder.FieldUserName: + return m.UserName() + case paymentorder.FieldUserNotes: + return m.UserNotes() + case paymentorder.FieldAmount: + return m.Amount() + case paymentorder.FieldPayAmount: + return m.PayAmount() + case paymentorder.FieldFeeRate: + return m.FeeRate() + case paymentorder.FieldRechargeCode: + return m.RechargeCode() + case paymentorder.FieldOutTradeNo: + return m.OutTradeNo() + case paymentorder.FieldPaymentType: + return m.PaymentType() + case paymentorder.FieldPaymentTradeNo: + return m.PaymentTradeNo() + case paymentorder.FieldPayURL: + return m.PayURL() + case paymentorder.FieldQrCode: + return m.QrCode() + case paymentorder.FieldQrCodeImg: + return m.QrCodeImg() + case paymentorder.FieldOrderType: + return m.OrderType() + case paymentorder.FieldPlanID: + return m.PlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.SubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.SubscriptionDays() + case paymentorder.FieldProviderInstanceID: + return m.ProviderInstanceID() + case paymentorder.FieldProviderKey: + return m.ProviderKey() + case paymentorder.FieldProviderSnapshot: + return m.ProviderSnapshot() + case paymentorder.FieldStatus: + return m.Status() + case paymentorder.FieldRefundAmount: + return m.RefundAmount() + case paymentorder.FieldRefundReason: + return m.RefundReason() + case paymentorder.FieldRefundAt: + return m.RefundAt() + case paymentorder.FieldForceRefund: + return m.ForceRefund() + case paymentorder.FieldRefundRequestedAt: + return m.RefundRequestedAt() + case paymentorder.FieldRefundRequestReason: + return m.RefundRequestReason() + case paymentorder.FieldRefundRequestedBy: + return m.RefundRequestedBy() + case paymentorder.FieldExpiresAt: + return m.ExpiresAt() + case paymentorder.FieldPaidAt: + return m.PaidAt() + case paymentorder.FieldCompletedAt: + return m.CompletedAt() + case paymentorder.FieldFailedAt: + return m.FailedAt() + case paymentorder.FieldFailedReason: + return m.FailedReason() + case paymentorder.FieldClientIP: + return m.ClientIP() + case paymentorder.FieldSrcHost: + return m.SrcHost() + case paymentorder.FieldSrcURL: + return m.SrcURL() + case paymentorder.FieldCreatedAt: + return m.CreatedAt() + case paymentorder.FieldUpdatedAt: + return m.UpdatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentorder.FieldUserID: + return m.OldUserID(ctx) + case paymentorder.FieldUserEmail: + return m.OldUserEmail(ctx) + case paymentorder.FieldUserName: + return m.OldUserName(ctx) + case paymentorder.FieldUserNotes: + return m.OldUserNotes(ctx) + case paymentorder.FieldAmount: + return m.OldAmount(ctx) + case paymentorder.FieldPayAmount: + return m.OldPayAmount(ctx) + case paymentorder.FieldFeeRate: + return m.OldFeeRate(ctx) + case paymentorder.FieldRechargeCode: + return m.OldRechargeCode(ctx) + case paymentorder.FieldOutTradeNo: + return m.OldOutTradeNo(ctx) + case paymentorder.FieldPaymentType: + return m.OldPaymentType(ctx) + case paymentorder.FieldPaymentTradeNo: + return m.OldPaymentTradeNo(ctx) + case paymentorder.FieldPayURL: + return m.OldPayURL(ctx) + case paymentorder.FieldQrCode: + return m.OldQrCode(ctx) + case paymentorder.FieldQrCodeImg: + return m.OldQrCodeImg(ctx) + case paymentorder.FieldOrderType: + return m.OldOrderType(ctx) + case paymentorder.FieldPlanID: + return m.OldPlanID(ctx) + case paymentorder.FieldSubscriptionGroupID: + return m.OldSubscriptionGroupID(ctx) + case paymentorder.FieldSubscriptionDays: + return m.OldSubscriptionDays(ctx) + case paymentorder.FieldProviderInstanceID: + return m.OldProviderInstanceID(ctx) + case paymentorder.FieldProviderKey: + return m.OldProviderKey(ctx) + case paymentorder.FieldProviderSnapshot: + return m.OldProviderSnapshot(ctx) + case paymentorder.FieldStatus: + return m.OldStatus(ctx) + case paymentorder.FieldRefundAmount: + return m.OldRefundAmount(ctx) + case paymentorder.FieldRefundReason: + return m.OldRefundReason(ctx) + case paymentorder.FieldRefundAt: + return m.OldRefundAt(ctx) + case paymentorder.FieldForceRefund: + return m.OldForceRefund(ctx) + case paymentorder.FieldRefundRequestedAt: + return m.OldRefundRequestedAt(ctx) + case paymentorder.FieldRefundRequestReason: + return m.OldRefundRequestReason(ctx) + case paymentorder.FieldRefundRequestedBy: + return m.OldRefundRequestedBy(ctx) + case paymentorder.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case paymentorder.FieldPaidAt: + return m.OldPaidAt(ctx) + case paymentorder.FieldCompletedAt: + return m.OldCompletedAt(ctx) + case paymentorder.FieldFailedAt: + return m.OldFailedAt(ctx) + case paymentorder.FieldFailedReason: + return m.OldFailedReason(ctx) + case paymentorder.FieldClientIP: + return m.OldClientIP(ctx) + case paymentorder.FieldSrcHost: + return m.OldSrcHost(ctx) + case paymentorder.FieldSrcURL: + return m.OldSrcURL(ctx) + case paymentorder.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case paymentorder.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + } + return nil, fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case paymentorder.FieldUserEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserEmail(v) + return nil + case paymentorder.FieldUserName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserName(v) + return nil + case paymentorder.FieldUserNotes: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserNotes(v) + return nil + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeeRate(v) + return nil + case paymentorder.FieldRechargeCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRechargeCode(v) + return nil + case paymentorder.FieldOutTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutTradeNo(v) + return nil + case paymentorder.FieldPaymentType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentType(v) + return nil + case paymentorder.FieldPaymentTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentTradeNo(v) + return nil + case paymentorder.FieldPayURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayURL(v) + return nil + case paymentorder.FieldQrCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCode(v) + return nil + case paymentorder.FieldQrCodeImg: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCodeImg(v) + return nil + case paymentorder.FieldOrderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderType(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionDays(v) + return nil + case paymentorder.FieldProviderInstanceID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderInstanceID(v) + return nil + case paymentorder.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case paymentorder.FieldProviderSnapshot: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSnapshot(v) + return nil + case paymentorder.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundAmount(v) + return nil + case paymentorder.FieldRefundReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundReason(v) + return nil + case paymentorder.FieldRefundAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundAt(v) + return nil + case paymentorder.FieldForceRefund: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundAmount(v) + m.SetForceRefund(v) + return nil + case paymentorder.FieldRefundRequestedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedAt(v) + return nil + case paymentorder.FieldRefundRequestReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestReason(v) + return nil + case paymentorder.FieldRefundRequestedBy: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedBy(v) + return nil + case paymentorder.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case paymentorder.FieldPaidAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaidAt(v) + return nil + case paymentorder.FieldCompletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletedAt(v) + return nil + case paymentorder.FieldFailedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedAt(v) + return nil + case paymentorder.FieldFailedReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedReason(v) + return nil + case paymentorder.FieldClientIP: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClientIP(v) + return nil + case paymentorder.FieldSrcHost: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcHost(v) + return nil + case paymentorder.FieldSrcURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcURL(v) + return nil + case paymentorder.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case paymentorder.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentOrderMutation) AddedFields() []string { + var fields []string + if m.addamount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.addpay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.addfee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.addplan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.addsubscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.addsubscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.addrefund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldAmount: + return m.AddedAmount() + case paymentorder.FieldPayAmount: + return m.AddedPayAmount() + case paymentorder.FieldFeeRate: + return m.AddedFeeRate() + case paymentorder.FieldPlanID: + return m.AddedPlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.AddedSubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.AddedSubscriptionDays() + case paymentorder.FieldRefundAmount: + return m.AddedRefundAmount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFeeRate(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionDays(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRefundAmount(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentOrderMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(paymentorder.FieldUserNotes) { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.FieldCleared(paymentorder.FieldPayURL) { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.FieldCleared(paymentorder.FieldQrCode) { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.FieldCleared(paymentorder.FieldQrCodeImg) { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.FieldCleared(paymentorder.FieldPlanID) { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionDays) { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.FieldCleared(paymentorder.FieldProviderInstanceID) { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.FieldCleared(paymentorder.FieldProviderKey) { + fields = append(fields, paymentorder.FieldProviderKey) + } + if m.FieldCleared(paymentorder.FieldProviderSnapshot) { + fields = append(fields, paymentorder.FieldProviderSnapshot) + } + if m.FieldCleared(paymentorder.FieldRefundReason) { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.FieldCleared(paymentorder.FieldRefundAt) { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestReason) { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.FieldCleared(paymentorder.FieldPaidAt) { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.FieldCleared(paymentorder.FieldCompletedAt) { + fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.FieldCleared(paymentorder.FieldFailedAt) { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.FieldCleared(paymentorder.FieldFailedReason) { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.FieldCleared(paymentorder.FieldSrcURL) { + fields = append(fields, paymentorder.FieldSrcURL) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentOrderMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ClearField(name string) error { + switch name { + case paymentorder.FieldUserNotes: + m.ClearUserNotes() + return nil + case paymentorder.FieldPayURL: + m.ClearPayURL() + return nil + case paymentorder.FieldQrCode: + m.ClearQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ClearQrCodeImg() + return nil + case paymentorder.FieldPlanID: + m.ClearPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ClearSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ClearSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + m.ClearProviderInstanceID() + return nil + case paymentorder.FieldProviderKey: + m.ClearProviderKey() + return nil + case paymentorder.FieldProviderSnapshot: + m.ClearProviderSnapshot() return nil case paymentorder.FieldRefundReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundReason(v) + m.ClearRefundReason() return nil case paymentorder.FieldRefundAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundAt(v) + m.ClearRefundAt() + return nil + case paymentorder.FieldRefundRequestedAt: + m.ClearRefundRequestedAt() + return nil + case paymentorder.FieldRefundRequestReason: + m.ClearRefundRequestReason() + return nil + case paymentorder.FieldRefundRequestedBy: + m.ClearRefundRequestedBy() + return nil + case paymentorder.FieldPaidAt: + m.ClearPaidAt() + return nil + case paymentorder.FieldCompletedAt: + m.ClearCompletedAt() + return nil + case paymentorder.FieldFailedAt: + m.ClearFailedAt() + return nil + case paymentorder.FieldFailedReason: + m.ClearFailedReason() + return nil + case paymentorder.FieldSrcURL: + m.ClearSrcURL() + return nil + } + return fmt.Errorf("unknown PaymentOrder nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ResetField(name string) error { + switch name { + case paymentorder.FieldUserID: + m.ResetUserID() + return nil + case paymentorder.FieldUserEmail: + m.ResetUserEmail() + return nil + case paymentorder.FieldUserName: + m.ResetUserName() + return nil + case paymentorder.FieldUserNotes: + m.ResetUserNotes() + return nil + case paymentorder.FieldAmount: + m.ResetAmount() + return nil + case paymentorder.FieldPayAmount: + m.ResetPayAmount() + return nil + case paymentorder.FieldFeeRate: + m.ResetFeeRate() + return nil + case paymentorder.FieldRechargeCode: + m.ResetRechargeCode() + return nil + case paymentorder.FieldOutTradeNo: + m.ResetOutTradeNo() + return nil + case paymentorder.FieldPaymentType: + m.ResetPaymentType() + return nil + case paymentorder.FieldPaymentTradeNo: + m.ResetPaymentTradeNo() + return nil + case paymentorder.FieldPayURL: + m.ResetPayURL() + return nil + case paymentorder.FieldQrCode: + m.ResetQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ResetQrCodeImg() + return nil + case paymentorder.FieldOrderType: + m.ResetOrderType() + return nil + case paymentorder.FieldPlanID: + m.ResetPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ResetSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ResetSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + m.ResetProviderInstanceID() + return nil + case paymentorder.FieldProviderKey: + m.ResetProviderKey() + return nil + case paymentorder.FieldProviderSnapshot: + m.ResetProviderSnapshot() + return nil + case paymentorder.FieldStatus: + m.ResetStatus() + return nil + case paymentorder.FieldRefundAmount: + m.ResetRefundAmount() + return nil + case paymentorder.FieldRefundReason: + m.ResetRefundReason() + return nil + case paymentorder.FieldRefundAt: + m.ResetRefundAt() return nil case paymentorder.FieldForceRefund: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetForceRefund(v) + m.ResetForceRefund() return nil case paymentorder.FieldRefundRequestedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestedAt(v) + m.ResetRefundRequestedAt() return nil case paymentorder.FieldRefundRequestReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestReason(v) + m.ResetRefundRequestReason() return nil case paymentorder.FieldRefundRequestedBy: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestedBy(v) + m.ResetRefundRequestedBy() return nil case paymentorder.FieldExpiresAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetExpiresAt(v) + m.ResetExpiresAt() return nil case paymentorder.FieldPaidAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaidAt(v) + m.ResetPaidAt() return nil case paymentorder.FieldCompletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCompletedAt(v) + m.ResetCompletedAt() return nil case paymentorder.FieldFailedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFailedAt(v) + m.ResetFailedAt() return nil case paymentorder.FieldFailedReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFailedReason(v) + m.ResetFailedReason() return nil case paymentorder.FieldClientIP: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.ResetClientIP() + return nil + case paymentorder.FieldSrcHost: + m.ResetSrcHost() + return nil + case paymentorder.FieldSrcURL: + m.ResetSrcURL() + return nil + case paymentorder.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case paymentorder.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentOrderMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { + switch name { + case paymentorder.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} } - m.SetClientIP(v) + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentOrderMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentOrderMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentOrderMutation) EdgeCleared(name string) bool { + switch name { + case paymentorder.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentOrderMutation) ClearEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ClearUser() return nil - case paymentorder.FieldSrcHost: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSrcHost(v) + } + return fmt.Errorf("unknown PaymentOrder unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentOrderMutation) ResetEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ResetUser() return nil - case paymentorder.FieldSrcURL: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + } + return fmt.Errorf("unknown PaymentOrder edge %s", name) +} + +// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. +type PaymentProviderInstanceMutation struct { + config + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance +} + +var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) + +// paymentproviderinstanceOption allows management of the mutation configuration using functional options. +type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) + +// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. +func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { + m := &PaymentProviderInstanceMutation{ + config: c, + op: op, + typ: TypePaymentProviderInstance, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPaymentProviderInstanceID sets the ID field of the mutation. +func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + var ( + err error + once sync.Once + value *PaymentProviderInstance + ) + m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + } + }) + return value, err } - m.SetSrcURL(v) - return nil - case paymentorder.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.id = &id + } +} + +// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. +func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { + return node, nil } - m.SetCreatedAt(v) - return nil - case paymentorder.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentProviderInstanceMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil } - m.SetUpdatedAt(v) - return nil + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetProviderKey sets the "provider_key" field. +func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentProviderInstanceMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetName sets the "name" field. +func (m *PaymentProviderInstanceMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *PaymentProviderInstanceMutation) ResetName() { + m.name = nil +} + +// SetConfig sets the "config" field. +func (m *PaymentProviderInstanceMutation) SetConfig(s string) { + m._config = &s +} + +// Config returns the value of the "config" field in the mutation. +func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { + v := m._config + if v == nil { + return + } + return *v, true +} + +// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConfig: %w", err) + } + return oldValue.Config, nil +} + +// ResetConfig resets all changes to the "config" field. +func (m *PaymentProviderInstanceMutation) ResetConfig() { + m._config = nil +} + +// SetSupportedTypes sets the "supported_types" field. +func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { + m.supported_types = &s +} + +// SupportedTypes returns the value of the "supported_types" field in the mutation. +func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { + v := m.supported_types + if v == nil { + return + } + return *v, true +} + +// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + } + return oldValue.SupportedTypes, nil +} + +// ResetSupportedTypes resets all changes to the "supported_types" field. +func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { + m.supported_types = nil +} + +// SetEnabled sets the "enabled" field. +func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. +func (m *PaymentProviderInstanceMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPaymentMode sets the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { + m.payment_mode = &s +} + +// PaymentMode returns the value of the "payment_mode" field in the mutation. +func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { + v := m.payment_mode + if v == nil { + return + } + return *v, true +} + +// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + } + return oldValue.PaymentMode, nil +} + +// ResetPaymentMode resets all changes to the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { + m.payment_mode = nil +} + +// SetSortOrder sets the "sort_order" field. +func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil +} + +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return + } + return *v, true +} + +// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil +} + +// AddSortOrder adds i to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i + } +} + +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { + return + } + return *v, true +} + +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil +} + +// SetLimits sets the "limits" field. +func (m *PaymentProviderInstanceMutation) SetLimits(s string) { + m.limits = &s +} + +// Limits returns the value of the "limits" field in the mutation. +func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { + v := m.limits + if v == nil { + return + } + return *v, true +} + +// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLimits is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLimits requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLimits: %w", err) + } + return oldValue.Limits, nil +} + +// ResetLimits resets all changes to the "limits" field. +func (m *PaymentProviderInstanceMutation) ResetLimits() { + m.limits = nil +} + +// SetRefundEnabled sets the "refund_enabled" field. +func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { + m.refund_enabled = &b +} + +// RefundEnabled returns the value of the "refund_enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { + v := m.refund_enabled + if v == nil { + return } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return *v, true } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *PaymentOrderMutation) AddedFields() []string { - var fields []string - if m.addamount != nil { - fields = append(fields, paymentorder.FieldAmount) +// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") } - if m.addpay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundEnabled requires an ID field in the mutation") } - if m.addfee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) } - if m.addplan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) + return oldValue.RefundEnabled, nil +} + +// ResetRefundEnabled resets all changes to the "refund_enabled" field. +func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { + m.refund_enabled = nil +} + +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return } - if m.addsubscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") } - if m.addsubscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") } - if m.addrefund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) } - return fields + return oldValue.AllowUserRefund, nil } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { - switch name { - case paymentorder.FieldAmount: - return m.AddedAmount() - case paymentorder.FieldPayAmount: - return m.AddedPayAmount() - case paymentorder.FieldFeeRate: - return m.AddedFeeRate() - case paymentorder.FieldPlanID: - return m.AddedPlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.AddedSubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.AddedSubscriptionDays() - case paymentorder.FieldRefundAmount: - return m.AddedRefundAmount() - } - return nil, false +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { - switch name { - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFeeRate(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSubscriptionDays(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRefundAmount(v) - return nil +// SetCreatedAt sets the "created_at" field. +func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return } - return fmt.Errorf("unknown PaymentOrder numeric field %s", name) + return *v, true } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *PaymentOrderMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(paymentorder.FieldUserNotes) { - fields = append(fields, paymentorder.FieldUserNotes) +// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } - if m.FieldCleared(paymentorder.FieldPayURL) { - fields = append(fields, paymentorder.FieldPayURL) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } - if m.FieldCleared(paymentorder.FieldQrCode) { - fields = append(fields, paymentorder.FieldQrCode) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - if m.FieldCleared(paymentorder.FieldQrCodeImg) { - fields = append(fields, paymentorder.FieldQrCodeImg) + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return } - if m.FieldCleared(paymentorder.FieldPlanID) { - fields = append(fields, paymentorder.FieldPlanID) + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } - if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } - if m.FieldCleared(paymentorder.FieldSubscriptionDays) { - fields = append(fields, paymentorder.FieldSubscriptionDays) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// Where appends a list predicates to the PaymentProviderInstanceMutation builder. +func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentProviderInstance, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentProviderInstanceMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PaymentProviderInstanceMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentProviderInstance). +func (m *PaymentProviderInstanceMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentProviderInstanceMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.provider_key != nil { + fields = append(fields, paymentproviderinstance.FieldProviderKey) } - if m.FieldCleared(paymentorder.FieldProviderInstanceID) { - fields = append(fields, paymentorder.FieldProviderInstanceID) + if m.name != nil { + fields = append(fields, paymentproviderinstance.FieldName) } - if m.FieldCleared(paymentorder.FieldRefundReason) { - fields = append(fields, paymentorder.FieldRefundReason) + if m._config != nil { + fields = append(fields, paymentproviderinstance.FieldConfig) } - if m.FieldCleared(paymentorder.FieldRefundAt) { - fields = append(fields, paymentorder.FieldRefundAt) + if m.supported_types != nil { + fields = append(fields, paymentproviderinstance.FieldSupportedTypes) } - if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { - fields = append(fields, paymentorder.FieldRefundRequestedAt) + if m.enabled != nil { + fields = append(fields, paymentproviderinstance.FieldEnabled) } - if m.FieldCleared(paymentorder.FieldRefundRequestReason) { - fields = append(fields, paymentorder.FieldRefundRequestReason) + if m.payment_mode != nil { + fields = append(fields, paymentproviderinstance.FieldPaymentMode) } - if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { - fields = append(fields, paymentorder.FieldRefundRequestedBy) + if m.sort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) } - if m.FieldCleared(paymentorder.FieldPaidAt) { - fields = append(fields, paymentorder.FieldPaidAt) + if m.limits != nil { + fields = append(fields, paymentproviderinstance.FieldLimits) } - if m.FieldCleared(paymentorder.FieldCompletedAt) { - fields = append(fields, paymentorder.FieldCompletedAt) + if m.refund_enabled != nil { + fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } - if m.FieldCleared(paymentorder.FieldFailedAt) { - fields = append(fields, paymentorder.FieldFailedAt) + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) } - if m.FieldCleared(paymentorder.FieldFailedReason) { - fields = append(fields, paymentorder.FieldFailedReason) + if m.created_at != nil { + fields = append(fields, paymentproviderinstance.FieldCreatedAt) } - if m.FieldCleared(paymentorder.FieldSrcURL) { - fields = append(fields, paymentorder.FieldSrcURL) + if m.updated_at != nil { + fields = append(fields, paymentproviderinstance.FieldUpdatedAt) } return fields } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *PaymentOrderMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentproviderinstance.FieldProviderKey: + return m.ProviderKey() + case paymentproviderinstance.FieldName: + return m.Name() + case paymentproviderinstance.FieldConfig: + return m.Config() + case paymentproviderinstance.FieldSupportedTypes: + return m.SupportedTypes() + case paymentproviderinstance.FieldEnabled: + return m.Enabled() + case paymentproviderinstance.FieldPaymentMode: + return m.PaymentMode() + case paymentproviderinstance.FieldSortOrder: + return m.SortOrder() + case paymentproviderinstance.FieldLimits: + return m.Limits() + case paymentproviderinstance.FieldRefundEnabled: + return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() + case paymentproviderinstance.FieldCreatedAt: + return m.CreatedAt() + case paymentproviderinstance.FieldUpdatedAt: + return m.UpdatedAt() + } + return nil, false } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ClearField(name string) error { +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentorder.FieldUserNotes: - m.ClearUserNotes() - return nil - case paymentorder.FieldPayURL: - m.ClearPayURL() - return nil - case paymentorder.FieldQrCode: - m.ClearQrCode() - return nil - case paymentorder.FieldQrCodeImg: - m.ClearQrCodeImg() - return nil - case paymentorder.FieldPlanID: - m.ClearPlanID() - return nil - case paymentorder.FieldSubscriptionGroupID: - m.ClearSubscriptionGroupID() - return nil - case paymentorder.FieldSubscriptionDays: - m.ClearSubscriptionDays() - return nil - case paymentorder.FieldProviderInstanceID: - m.ClearProviderInstanceID() - return nil - case paymentorder.FieldRefundReason: - m.ClearRefundReason() - return nil - case paymentorder.FieldRefundAt: - m.ClearRefundAt() - return nil - case paymentorder.FieldRefundRequestedAt: - m.ClearRefundRequestedAt() - return nil - case paymentorder.FieldRefundRequestReason: - m.ClearRefundRequestReason() - return nil - case paymentorder.FieldRefundRequestedBy: - m.ClearRefundRequestedBy() - return nil - case paymentorder.FieldPaidAt: - m.ClearPaidAt() - return nil - case paymentorder.FieldCompletedAt: - m.ClearCompletedAt() - return nil - case paymentorder.FieldFailedAt: - m.ClearFailedAt() - return nil - case paymentorder.FieldFailedReason: - m.ClearFailedReason() - return nil - case paymentorder.FieldSrcURL: - m.ClearSrcURL() - return nil + case paymentproviderinstance.FieldProviderKey: + return m.OldProviderKey(ctx) + case paymentproviderinstance.FieldName: + return m.OldName(ctx) + case paymentproviderinstance.FieldConfig: + return m.OldConfig(ctx) + case paymentproviderinstance.FieldSupportedTypes: + return m.OldSupportedTypes(ctx) + case paymentproviderinstance.FieldEnabled: + return m.OldEnabled(ctx) + case paymentproviderinstance.FieldPaymentMode: + return m.OldPaymentMode(ctx) + case paymentproviderinstance.FieldSortOrder: + return m.OldSortOrder(ctx) + case paymentproviderinstance.FieldLimits: + return m.OldLimits(ctx) + case paymentproviderinstance.FieldRefundEnabled: + return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) + case paymentproviderinstance.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case paymentproviderinstance.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) } - return fmt.Errorf("unknown PaymentOrder nullable field %s", name) + return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ResetField(name string) error { +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { switch name { - case paymentorder.FieldUserID: - m.ResetUserID() - return nil - case paymentorder.FieldUserEmail: - m.ResetUserEmail() - return nil - case paymentorder.FieldUserName: - m.ResetUserName() - return nil - case paymentorder.FieldUserNotes: - m.ResetUserNotes() - return nil - case paymentorder.FieldAmount: - m.ResetAmount() - return nil - case paymentorder.FieldPayAmount: - m.ResetPayAmount() - return nil - case paymentorder.FieldFeeRate: - m.ResetFeeRate() - return nil - case paymentorder.FieldRechargeCode: - m.ResetRechargeCode() - return nil - case paymentorder.FieldOutTradeNo: - m.ResetOutTradeNo() - return nil - case paymentorder.FieldPaymentType: - m.ResetPaymentType() - return nil - case paymentorder.FieldPaymentTradeNo: - m.ResetPaymentTradeNo() - return nil - case paymentorder.FieldPayURL: - m.ResetPayURL() - return nil - case paymentorder.FieldQrCode: - m.ResetQrCode() + case paymentproviderinstance.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) return nil - case paymentorder.FieldQrCodeImg: - m.ResetQrCodeImg() + case paymentproviderinstance.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) return nil - case paymentorder.FieldOrderType: - m.ResetOrderType() + case paymentproviderinstance.FieldConfig: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConfig(v) return nil - case paymentorder.FieldPlanID: - m.ResetPlanID() + case paymentproviderinstance.FieldSupportedTypes: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedTypes(v) return nil - case paymentorder.FieldSubscriptionGroupID: - m.ResetSubscriptionGroupID() + case paymentproviderinstance.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) return nil - case paymentorder.FieldSubscriptionDays: - m.ResetSubscriptionDays() + case paymentproviderinstance.FieldPaymentMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentMode(v) return nil - case paymentorder.FieldProviderInstanceID: - m.ResetProviderInstanceID() + case paymentproviderinstance.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) return nil - case paymentorder.FieldStatus: - m.ResetStatus() + case paymentproviderinstance.FieldLimits: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLimits(v) return nil - case paymentorder.FieldRefundAmount: - m.ResetRefundAmount() + case paymentproviderinstance.FieldRefundEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundEnabled(v) return nil - case paymentorder.FieldRefundReason: - m.ResetRefundReason() + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) return nil - case paymentorder.FieldRefundAt: - m.ResetRefundAt() + case paymentproviderinstance.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) return nil - case paymentorder.FieldForceRefund: - m.ResetForceRefund() + case paymentproviderinstance.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) return nil - case paymentorder.FieldRefundRequestedAt: - m.ResetRefundRequestedAt() + } + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentProviderInstanceMutation) AddedFields() []string { + var fields []string + if m.addsort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case paymentproviderinstance.FieldSortOrder: + return m.AddedSortOrder() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { + switch name { + case paymentproviderinstance.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) return nil - case paymentorder.FieldRefundRequestReason: - m.ResetRefundRequestReason() + } + return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentProviderInstanceMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ResetField(name string) error { + switch name { + case paymentproviderinstance.FieldProviderKey: + m.ResetProviderKey() return nil - case paymentorder.FieldRefundRequestedBy: - m.ResetRefundRequestedBy() + case paymentproviderinstance.FieldName: + m.ResetName() return nil - case paymentorder.FieldExpiresAt: - m.ResetExpiresAt() + case paymentproviderinstance.FieldConfig: + m.ResetConfig() return nil - case paymentorder.FieldPaidAt: - m.ResetPaidAt() + case paymentproviderinstance.FieldSupportedTypes: + m.ResetSupportedTypes() return nil - case paymentorder.FieldCompletedAt: - m.ResetCompletedAt() + case paymentproviderinstance.FieldEnabled: + m.ResetEnabled() return nil - case paymentorder.FieldFailedAt: - m.ResetFailedAt() + case paymentproviderinstance.FieldPaymentMode: + m.ResetPaymentMode() return nil - case paymentorder.FieldFailedReason: - m.ResetFailedReason() + case paymentproviderinstance.FieldSortOrder: + m.ResetSortOrder() return nil - case paymentorder.FieldClientIP: - m.ResetClientIP() + case paymentproviderinstance.FieldLimits: + m.ResetLimits() return nil - case paymentorder.FieldSrcHost: - m.ResetSrcHost() + case paymentproviderinstance.FieldRefundEnabled: + m.ResetRefundEnabled() return nil - case paymentorder.FieldSrcURL: - m.ResetSrcURL() + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() return nil - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil - case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: m.ResetUpdatedAt() return nil } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentOrderMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.user != nil { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) AddedEdges() []string { + edges := make([]string, 0, 0) return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { - switch name { - case paymentorder.EdgeUser: - if id := m.user; id != nil { - return []ent.Value{*id} - } - } +func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentOrderMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) +func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { +func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentOrderMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.cleareduser { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentOrderMutation) EdgeCleared(name string) bool { - switch name { - case paymentorder.EdgeUser: - return m.cleareduser - } - return false -} - -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *PaymentOrderMutation) ClearEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ClearUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder unique edge %s", name) -} - -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *PaymentOrderMutation) ResetEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ResetUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder edge %s", name) -} - -// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. -type PaymentProviderInstanceMutation struct { - config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - allow_user_refund *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance +func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { + return false } -var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) +} -// paymentproviderinstanceOption allows management of the mutation configuration using functional options. -type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) +} -// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. -func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { - m := &PaymentProviderInstanceMutation{ +// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph. +type PendingAuthSessionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + session_token *string + intent *string + provider_type *string + provider_key *string + provider_subject *string + redirect_to *string + resolved_email *string + registration_password_hash *string + upstream_identity_claims *map[string]interface{} + local_flow_state *map[string]interface{} + browser_session_key *string + completion_code_hash *string + completion_code_expires_at *time.Time + email_verified_at *time.Time + password_verified_at *time.Time + totp_verified_at *time.Time + expires_at *time.Time + consumed_at *time.Time + clearedFields map[string]struct{} + target_user *int64 + clearedtarget_user bool + adoption_decision *int64 + clearedadoption_decision bool + done bool + oldValue func(context.Context) (*PendingAuthSession, error) + predicates []predicate.PendingAuthSession +} + +var _ ent.Mutation = (*PendingAuthSessionMutation)(nil) + +// pendingauthsessionOption allows management of the mutation configuration using functional options. +type pendingauthsessionOption func(*PendingAuthSessionMutation) + +// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity. +func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation { + m := &PendingAuthSessionMutation{ config: c, op: op, - typ: TypePaymentProviderInstance, + typ: TypePendingAuthSession, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -15683,20 +19418,20 @@ func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentprovider return m } -// withPaymentProviderInstanceID sets the ID field of the mutation. -func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { +// withPendingAuthSessionID sets the ID field of the mutation. +func withPendingAuthSessionID(id int64) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { var ( err error once sync.Once - value *PaymentProviderInstance + value *PendingAuthSession ) - m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + value, err = m.Client().PendingAuthSession.Get(ctx, id) } }) return value, err @@ -15705,10 +19440,10 @@ func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { } } -// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. -func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { - m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { +// withPendingAuthSession sets the old PendingAuthSession of the mutation. +func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + m.oldValue = func(context.Context) (*PendingAuthSession, error) { return node, nil } m.id = &node.ID @@ -15717,7 +19452,7 @@ func withPaymentProviderInstance(node *PaymentProviderInstance) paymentprovideri // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentProviderInstanceMutation) Client() *Client { +func (m PendingAuthSessionMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -15725,7 +19460,7 @@ func (m PaymentProviderInstanceMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { +func (m PendingAuthSessionMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -15734,495 +19469,943 @@ func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { return tx, nil } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { - if m.id == nil { - return - } - return *m.id, true +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PendingAuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PendingAuthSessionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetSessionToken sets the "session_token" field. +func (m *PendingAuthSessionMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. +func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *PendingAuthSessionMutation) ResetSessionToken() { + m.session_token = nil +} + +// SetIntent sets the "intent" field. +func (m *PendingAuthSessionMutation) SetIntent(s string) { + m.intent = &s +} + +// Intent returns the value of the "intent" field in the mutation. +func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) { + v := m.intent + if v == nil { + return + } + return *v, true +} + +// OldIntent returns the old "intent" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIntent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIntent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIntent: %w", err) + } + return oldValue.Intent, nil +} + +// ResetIntent resets all changes to the "intent" field. +func (m *PendingAuthSessionMutation) ResetIntent() { + m.intent = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *PendingAuthSessionMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *PendingAuthSessionMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *PendingAuthSessionMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PendingAuthSessionMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *PendingAuthSessionMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *PendingAuthSessionMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetTargetUserID sets the "target_user_id" field. +func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) { + m.target_user = &i +} + +// TargetUserID returns the value of the "target_user_id" field in the mutation. +func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) { + v := m.target_user + if v == nil { + return + } + return *v, true +} + +// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTargetUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err) + } + return oldValue.TargetUserID, nil +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (m *PendingAuthSessionMutation) ClearTargetUserID() { + m.target_user = nil + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID] + return ok } -// SetProviderKey sets the "provider_key" field. -func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { - m.provider_key = &s +// ResetTargetUserID resets all changes to the "target_user_id" field. +func (m *PendingAuthSessionMutation) ResetTargetUserID() { + m.target_user = nil + delete(m.clearedFields, pendingauthsession.FieldTargetUserID) } -// ProviderKey returns the value of the "provider_key" field in the mutation. -func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { - v := m.provider_key +// SetRedirectTo sets the "redirect_to" field. +func (m *PendingAuthSessionMutation) SetRedirectTo(s string) { + m.redirect_to = &s +} + +// RedirectTo returns the value of the "redirect_to" field in the mutation. +func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) { + v := m.redirect_to if v == nil { return } return *v, true } -// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldProviderKey requires an ID field in the mutation") + return v, errors.New("OldRedirectTo requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err) } - return oldValue.ProviderKey, nil + return oldValue.RedirectTo, nil } -// ResetProviderKey resets all changes to the "provider_key" field. -func (m *PaymentProviderInstanceMutation) ResetProviderKey() { - m.provider_key = nil +// ResetRedirectTo resets all changes to the "redirect_to" field. +func (m *PendingAuthSessionMutation) ResetRedirectTo() { + m.redirect_to = nil } -// SetName sets the "name" field. -func (m *PaymentProviderInstanceMutation) SetName(s string) { - m.name = &s +// SetResolvedEmail sets the "resolved_email" field. +func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) { + m.resolved_email = &s } -// Name returns the value of the "name" field in the mutation. -func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { - v := m.name +// ResolvedEmail returns the value of the "resolved_email" field in the mutation. +func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) { + v := m.resolved_email if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldResolvedEmail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err) } - return oldValue.Name, nil + return oldValue.ResolvedEmail, nil } -// ResetName resets all changes to the "name" field. -func (m *PaymentProviderInstanceMutation) ResetName() { - m.name = nil +// ResetResolvedEmail resets all changes to the "resolved_email" field. +func (m *PendingAuthSessionMutation) ResetResolvedEmail() { + m.resolved_email = nil } -// SetConfig sets the "config" field. -func (m *PaymentProviderInstanceMutation) SetConfig(s string) { - m._config = &s +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) { + m.registration_password_hash = &s } -// Config returns the value of the "config" field in the mutation. -func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { - v := m._config +// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation. +func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) { + v := m.registration_password_hash if v == nil { return } return *v, true } -// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldConfig is only allowed on UpdateOne operations") + return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldConfig requires an ID field in the mutation") + return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldConfig: %w", err) + return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err) } - return oldValue.Config, nil + return oldValue.RegistrationPasswordHash, nil } -// ResetConfig resets all changes to the "config" field. -func (m *PaymentProviderInstanceMutation) ResetConfig() { - m._config = nil +// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() { + m.registration_password_hash = nil } -// SetSupportedTypes sets the "supported_types" field. -func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { - m.supported_types = &s +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) { + m.upstream_identity_claims = &value } -// SupportedTypes returns the value of the "supported_types" field in the mutation. -func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { - v := m.supported_types +// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation. +func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) { + v := m.upstream_identity_claims if v == nil { return } return *v, true } -// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err) } - return oldValue.SupportedTypes, nil + return oldValue.UpstreamIdentityClaims, nil } -// ResetSupportedTypes resets all changes to the "supported_types" field. -func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { - m.supported_types = nil +// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() { + m.upstream_identity_claims = nil } -// SetEnabled sets the "enabled" field. -func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { - m.enabled = &b +// SetLocalFlowState sets the "local_flow_state" field. +func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) { + m.local_flow_state = &value } -// Enabled returns the value of the "enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// LocalFlowState returns the value of the "local_flow_state" field in the mutation. +func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) { + v := m.local_flow_state if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldLocalFlowState requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err) } - return oldValue.Enabled, nil + return oldValue.LocalFlowState, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *PaymentProviderInstanceMutation) ResetEnabled() { - m.enabled = nil +// ResetLocalFlowState resets all changes to the "local_flow_state" field. +func (m *PendingAuthSessionMutation) ResetLocalFlowState() { + m.local_flow_state = nil } -// SetPaymentMode sets the "payment_mode" field. -func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { - m.payment_mode = &s +// SetBrowserSessionKey sets the "browser_session_key" field. +func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) { + m.browser_session_key = &s } -// PaymentMode returns the value of the "payment_mode" field in the mutation. -func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { - v := m.payment_mode +// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation. +func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) { + v := m.browser_session_key if v == nil { return } return *v, true } -// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentMode requires an ID field in the mutation") + return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err) } - return oldValue.PaymentMode, nil + return oldValue.BrowserSessionKey, nil } -// ResetPaymentMode resets all changes to the "payment_mode" field. -func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { - m.payment_mode = nil +// ResetBrowserSessionKey resets all changes to the "browser_session_key" field. +func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() { + m.browser_session_key = nil } -// SetSortOrder sets the "sort_order" field. -func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) { + m.completion_code_hash = &s } -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { - v := m.sort_order +// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) { + v := m.completion_code_hash if v == nil { return } return *v, true } -// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { +func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") + return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err) } - return oldValue.SortOrder, nil + return oldValue.CompletionCodeHash, nil } -// AddSortOrder adds i to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i - } +// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() { + m.completion_code_hash = nil } -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. -func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) { + m.completion_code_expires_at = &t +} + +// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) { + v := m.completion_code_expires_at if v == nil { return } return *v, true } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err) + } + return oldValue.CompletionCodeExpiresAt, nil } -// SetLimits sets the "limits" field. -func (m *PaymentProviderInstanceMutation) SetLimits(s string) { - m.limits = &s +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{} } -// Limits returns the value of the "limits" field in the mutation. -func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { - v := m.limits +// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] + return ok +} + +// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) { + m.email_verified_at = &t +} + +// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) { + v := m.email_verified_at if v == nil { return } return *v, true } -// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLimits is only allowed on UpdateOne operations") + return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLimits requires an ID field in the mutation") + return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldLimits: %w", err) + return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err) } - return oldValue.Limits, nil + return oldValue.EmailVerifiedAt, nil } -// ResetLimits resets all changes to the "limits" field. -func (m *PaymentProviderInstanceMutation) ResetLimits() { - m.limits = nil +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() { + m.email_verified_at = nil + m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{} } -// SetRefundEnabled sets the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { - m.refund_enabled = &b +// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] + return ok } -// RefundEnabled returns the value of the "refund_enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { - v := m.refund_enabled +// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() { + m.email_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) { + m.password_verified_at = &t +} + +// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) { + v := m.password_verified_at if v == nil { return } return *v, true } -// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundEnabled requires an ID field in the mutation") + return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err) } - return oldValue.RefundEnabled, nil + return oldValue.PasswordVerifiedAt, nil } -// ResetRefundEnabled resets all changes to the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { - m.refund_enabled = nil +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() { + m.password_verified_at = nil + m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{} } -// SetAllowUserRefund sets the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { - m.allow_user_refund = &b +// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] + return ok } -// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. -func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { - v := m.allow_user_refund +// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() { + m.password_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) { + m.totp_verified_at = &t +} + +// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) { + v := m.totp_verified_at if v == nil { return } return *v, true } -// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err) } - return oldValue.AllowUserRefund, nil + return oldValue.TotpVerifiedAt, nil } -// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { - m.allow_user_refund = nil +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() { + m.totp_verified_at = nil + m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{} } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] + return ok } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() { + m.totp_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldExpiresAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.ExpiresAt, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { - m.created_at = nil +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PendingAuthSessionMutation) ResetExpiresAt() { + m.expires_at = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetConsumedAt sets the "consumed_at" field. +func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) { + m.consumed_at = &t } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// ConsumedAt returns the value of the "consumed_at" field in the mutation. +func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) { + v := m.consumed_at if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldConsumedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.ConsumedAt, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { - m.updated_at = nil +// ClearConsumedAt clears the value of the "consumed_at" field. +func (m *PendingAuthSessionMutation) ClearConsumedAt() { + m.consumed_at = nil + m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{} } -// Where appends a list predicates to the PaymentProviderInstanceMutation builder. -func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { +// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt] + return ok +} + +// ResetConsumedAt resets all changes to the "consumed_at" field. +func (m *PendingAuthSessionMutation) ResetConsumedAt() { + m.consumed_at = nil + delete(m.clearedFields, pendingauthsession.FieldConsumedAt) +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (m *PendingAuthSessionMutation) ClearTargetUser() { + m.clearedtarget_user = true + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserCleared reports if the "target_user" edge to the User entity was cleared. +func (m *PendingAuthSessionMutation) TargetUserCleared() bool { + return m.TargetUserIDCleared() || m.clearedtarget_user +} + +// TargetUserIDs returns the "target_user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// TargetUserID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) { + if id := m.target_user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetTargetUser resets all changes to the "target_user" edge. +func (m *PendingAuthSessionMutation) ResetTargetUser() { + m.target_user = nil + m.clearedtarget_user = false +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id. +func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) { + m.adoption_decision = &id +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (m *PendingAuthSessionMutation) ClearAdoptionDecision() { + m.clearedadoption_decision = true +} + +// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared. +func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool { + return m.clearedadoption_decision +} + +// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation. +func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) { + if m.adoption_decision != nil { + return *m.adoption_decision, true + } + return +} + +// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AdoptionDecisionID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) { + if id := m.adoption_decision; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAdoptionDecision resets all changes to the "adoption_decision" edge. +func (m *PendingAuthSessionMutation) ResetAdoptionDecision() { + m.adoption_decision = nil + m.clearedadoption_decision = false +} + +// Where appends a list predicates to the PendingAuthSessionMutation builder. +func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, +// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentProviderInstance, len(ps)) +func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PendingAuthSession, len(ps)) for i := range ps { p[i] = ps[i] } @@ -16230,60 +20413,87 @@ func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *PaymentProviderInstanceMutation) Op() Op { +func (m *PendingAuthSessionMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *PaymentProviderInstanceMutation) SetOp(op Op) { +func (m *PendingAuthSessionMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (PaymentProviderInstance). -func (m *PaymentProviderInstanceMutation) Type() string { +// Type returns the node type of this mutation (PendingAuthSession). +func (m *PendingAuthSessionMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 12) +func (m *PendingAuthSessionMutation) Fields() []string { + fields := make([]string, 0, 21) + if m.created_at != nil { + fields = append(fields, pendingauthsession.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, pendingauthsession.FieldUpdatedAt) + } + if m.session_token != nil { + fields = append(fields, pendingauthsession.FieldSessionToken) + } + if m.intent != nil { + fields = append(fields, pendingauthsession.FieldIntent) + } + if m.provider_type != nil { + fields = append(fields, pendingauthsession.FieldProviderType) + } if m.provider_key != nil { - fields = append(fields, paymentproviderinstance.FieldProviderKey) + fields = append(fields, pendingauthsession.FieldProviderKey) } - if m.name != nil { - fields = append(fields, paymentproviderinstance.FieldName) + if m.provider_subject != nil { + fields = append(fields, pendingauthsession.FieldProviderSubject) } - if m._config != nil { - fields = append(fields, paymentproviderinstance.FieldConfig) + if m.target_user != nil { + fields = append(fields, pendingauthsession.FieldTargetUserID) } - if m.supported_types != nil { - fields = append(fields, paymentproviderinstance.FieldSupportedTypes) + if m.redirect_to != nil { + fields = append(fields, pendingauthsession.FieldRedirectTo) } - if m.enabled != nil { - fields = append(fields, paymentproviderinstance.FieldEnabled) + if m.resolved_email != nil { + fields = append(fields, pendingauthsession.FieldResolvedEmail) } - if m.payment_mode != nil { - fields = append(fields, paymentproviderinstance.FieldPaymentMode) + if m.registration_password_hash != nil { + fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash) } - if m.sort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) + if m.upstream_identity_claims != nil { + fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims) } - if m.limits != nil { - fields = append(fields, paymentproviderinstance.FieldLimits) + if m.local_flow_state != nil { + fields = append(fields, pendingauthsession.FieldLocalFlowState) } - if m.refund_enabled != nil { - fields = append(fields, paymentproviderinstance.FieldRefundEnabled) + if m.browser_session_key != nil { + fields = append(fields, pendingauthsession.FieldBrowserSessionKey) } - if m.allow_user_refund != nil { - fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + if m.completion_code_hash != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeHash) } - if m.created_at != nil { - fields = append(fields, paymentproviderinstance.FieldCreatedAt) + if m.completion_code_expires_at != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) } - if m.updated_at != nil { - fields = append(fields, paymentproviderinstance.FieldUpdatedAt) + if m.email_verified_at != nil { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.password_verified_at != nil { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.totp_verified_at != nil { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.expires_at != nil { + fields = append(fields, pendingauthsession.FieldExpiresAt) + } + if m.consumed_at != nil { + fields = append(fields, pendingauthsession.FieldConsumedAt) } return fields } @@ -16291,32 +20501,50 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.ProviderKey() - case paymentproviderinstance.FieldName: - return m.Name() - case paymentproviderinstance.FieldConfig: - return m.Config() - case paymentproviderinstance.FieldSupportedTypes: - return m.SupportedTypes() - case paymentproviderinstance.FieldEnabled: - return m.Enabled() - case paymentproviderinstance.FieldPaymentMode: - return m.PaymentMode() - case paymentproviderinstance.FieldSortOrder: - return m.SortOrder() - case paymentproviderinstance.FieldLimits: - return m.Limits() - case paymentproviderinstance.FieldRefundEnabled: - return m.RefundEnabled() - case paymentproviderinstance.FieldAllowUserRefund: - return m.AllowUserRefund() - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.CreatedAt() - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.UpdatedAt() + case pendingauthsession.FieldSessionToken: + return m.SessionToken() + case pendingauthsession.FieldIntent: + return m.Intent() + case pendingauthsession.FieldProviderType: + return m.ProviderType() + case pendingauthsession.FieldProviderKey: + return m.ProviderKey() + case pendingauthsession.FieldProviderSubject: + return m.ProviderSubject() + case pendingauthsession.FieldTargetUserID: + return m.TargetUserID() + case pendingauthsession.FieldRedirectTo: + return m.RedirectTo() + case pendingauthsession.FieldResolvedEmail: + return m.ResolvedEmail() + case pendingauthsession.FieldRegistrationPasswordHash: + return m.RegistrationPasswordHash() + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.UpstreamIdentityClaims() + case pendingauthsession.FieldLocalFlowState: + return m.LocalFlowState() + case pendingauthsession.FieldBrowserSessionKey: + return m.BrowserSessionKey() + case pendingauthsession.FieldCompletionCodeHash: + return m.CompletionCodeHash() + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.CompletionCodeExpiresAt() + case pendingauthsession.FieldEmailVerifiedAt: + return m.EmailVerifiedAt() + case pendingauthsession.FieldPasswordVerifiedAt: + return m.PasswordVerifiedAt() + case pendingauthsession.FieldTotpVerifiedAt: + return m.TotpVerifiedAt() + case pendingauthsession.FieldExpiresAt: + return m.ExpiresAt() + case pendingauthsession.FieldConsumedAt: + return m.ConsumedAt() } return nil, false } @@ -16324,146 +20552,222 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.OldProviderKey(ctx) - case paymentproviderinstance.FieldName: - return m.OldName(ctx) - case paymentproviderinstance.FieldConfig: - return m.OldConfig(ctx) - case paymentproviderinstance.FieldSupportedTypes: - return m.OldSupportedTypes(ctx) - case paymentproviderinstance.FieldEnabled: - return m.OldEnabled(ctx) - case paymentproviderinstance.FieldPaymentMode: - return m.OldPaymentMode(ctx) - case paymentproviderinstance.FieldSortOrder: - return m.OldSortOrder(ctx) - case paymentproviderinstance.FieldLimits: - return m.OldLimits(ctx) - case paymentproviderinstance.FieldRefundEnabled: - return m.OldRefundEnabled(ctx) - case paymentproviderinstance.FieldAllowUserRefund: - return m.OldAllowUserRefund(ctx) - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.OldCreatedAt(ctx) - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case pendingauthsession.FieldSessionToken: + return m.OldSessionToken(ctx) + case pendingauthsession.FieldIntent: + return m.OldIntent(ctx) + case pendingauthsession.FieldProviderType: + return m.OldProviderType(ctx) + case pendingauthsession.FieldProviderKey: + return m.OldProviderKey(ctx) + case pendingauthsession.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case pendingauthsession.FieldTargetUserID: + return m.OldTargetUserID(ctx) + case pendingauthsession.FieldRedirectTo: + return m.OldRedirectTo(ctx) + case pendingauthsession.FieldResolvedEmail: + return m.OldResolvedEmail(ctx) + case pendingauthsession.FieldRegistrationPasswordHash: + return m.OldRegistrationPasswordHash(ctx) + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.OldUpstreamIdentityClaims(ctx) + case pendingauthsession.FieldLocalFlowState: + return m.OldLocalFlowState(ctx) + case pendingauthsession.FieldBrowserSessionKey: + return m.OldBrowserSessionKey(ctx) + case pendingauthsession.FieldCompletionCodeHash: + return m.OldCompletionCodeHash(ctx) + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.OldCompletionCodeExpiresAt(ctx) + case pendingauthsession.FieldEmailVerifiedAt: + return m.OldEmailVerifiedAt(ctx) + case pendingauthsession.FieldPasswordVerifiedAt: + return m.OldPasswordVerifiedAt(ctx) + case pendingauthsession.FieldTotpVerifiedAt: + return m.OldTotpVerifiedAt(ctx) + case pendingauthsession.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case pendingauthsession.FieldConsumedAt: + return m.OldConsumedAt(ctx) } - return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return nil, fmt.Errorf("unknown PendingAuthSession field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case pendingauthsession.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case pendingauthsession.FieldSessionToken: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetProviderKey(v) + m.SetSessionToken(v) return nil - case paymentproviderinstance.FieldName: + case pendingauthsession.FieldIntent: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetName(v) + m.SetIntent(v) return nil - case paymentproviderinstance.FieldConfig: + case pendingauthsession.FieldProviderType: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetConfig(v) + m.SetProviderType(v) return nil - case paymentproviderinstance.FieldSupportedTypes: + case pendingauthsession.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSupportedTypes(v) + m.SetProviderKey(v) return nil - case paymentproviderinstance.FieldEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldProviderSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetEnabled(v) + m.SetProviderSubject(v) return nil - case paymentproviderinstance.FieldPaymentMode: + case pendingauthsession.FieldTargetUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTargetUserID(v) + return nil + case pendingauthsession.FieldRedirectTo: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPaymentMode(v) + m.SetRedirectTo(v) return nil - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) + case pendingauthsession.FieldResolvedEmail: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSortOrder(v) + m.SetResolvedEmail(v) return nil - case paymentproviderinstance.FieldLimits: + case pendingauthsession.FieldRegistrationPasswordHash: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLimits(v) + m.SetRegistrationPasswordHash(v) return nil - case paymentproviderinstance.FieldRefundEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldUpstreamIdentityClaims: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundEnabled(v) + m.SetUpstreamIdentityClaims(v) return nil - case paymentproviderinstance.FieldAllowUserRefund: - v, ok := value.(bool) + case pendingauthsession.FieldLocalFlowState: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAllowUserRefund(v) + m.SetLocalFlowState(v) return nil - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldBrowserSessionKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBrowserSessionKey(v) + return nil + case pendingauthsession.FieldCompletionCodeHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeHash(v) + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetCompletionCodeExpiresAt(v) return nil - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldEmailVerifiedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetEmailVerifiedAt(v) + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPasswordVerifiedAt(v) + return nil + case pendingauthsession.FieldTotpVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpVerifiedAt(v) + return nil + case pendingauthsession.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case pendingauthsession.FieldConsumedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsumedAt(v) return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *PaymentProviderInstanceMutation) AddedFields() []string { +func (m *PendingAuthSessionMutation) AddedFields() []string { var fields []string - if m.addsort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldSortOrder: - return m.AddedSortOrder() } return nil, false } @@ -16471,128 +20775,231 @@ func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bo // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) - return nil } - return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) + return fmt.Errorf("unknown PendingAuthSession numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *PaymentProviderInstanceMutation) ClearedFields() []string { - return nil +func (m *PendingAuthSessionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(pendingauthsession.FieldTargetUserID) { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldConsumedAt) { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { +func (m *PendingAuthSessionMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) +func (m *PendingAuthSessionMutation) ClearField(name string) error { + switch name { + case pendingauthsession.FieldTargetUserID: + m.ClearTargetUserID() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ClearCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ClearEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ClearPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ClearTotpVerifiedAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ClearConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ResetField(name string) error { +func (m *PendingAuthSessionMutation) ResetField(name string) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case pendingauthsession.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case pendingauthsession.FieldSessionToken: + m.ResetSessionToken() + return nil + case pendingauthsession.FieldIntent: + m.ResetIntent() + return nil + case pendingauthsession.FieldProviderType: + m.ResetProviderType() + return nil + case pendingauthsession.FieldProviderKey: m.ResetProviderKey() return nil - case paymentproviderinstance.FieldName: - m.ResetName() + case pendingauthsession.FieldProviderSubject: + m.ResetProviderSubject() return nil - case paymentproviderinstance.FieldConfig: - m.ResetConfig() + case pendingauthsession.FieldTargetUserID: + m.ResetTargetUserID() return nil - case paymentproviderinstance.FieldSupportedTypes: - m.ResetSupportedTypes() + case pendingauthsession.FieldRedirectTo: + m.ResetRedirectTo() return nil - case paymentproviderinstance.FieldEnabled: - m.ResetEnabled() + case pendingauthsession.FieldResolvedEmail: + m.ResetResolvedEmail() return nil - case paymentproviderinstance.FieldPaymentMode: - m.ResetPaymentMode() + case pendingauthsession.FieldRegistrationPasswordHash: + m.ResetRegistrationPasswordHash() return nil - case paymentproviderinstance.FieldSortOrder: - m.ResetSortOrder() + case pendingauthsession.FieldUpstreamIdentityClaims: + m.ResetUpstreamIdentityClaims() return nil - case paymentproviderinstance.FieldLimits: - m.ResetLimits() + case pendingauthsession.FieldLocalFlowState: + m.ResetLocalFlowState() return nil - case paymentproviderinstance.FieldRefundEnabled: - m.ResetRefundEnabled() + case pendingauthsession.FieldBrowserSessionKey: + m.ResetBrowserSessionKey() return nil - case paymentproviderinstance.FieldAllowUserRefund: - m.ResetAllowUserRefund() + case pendingauthsession.FieldCompletionCodeHash: + m.ResetCompletionCodeHash() return nil - case paymentproviderinstance.FieldCreatedAt: - m.ResetCreatedAt() + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ResetCompletionCodeExpiresAt() return nil - case paymentproviderinstance.FieldUpdatedAt: - m.ResetUpdatedAt() + case pendingauthsession.FieldEmailVerifiedAt: + m.ResetEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ResetPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ResetTotpVerifiedAt() + return nil + case pendingauthsession.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ResetConsumedAt() return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentProviderInstanceMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.target_user != nil { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.adoption_decision != nil { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value { + switch name { + case pendingauthsession.EdgeTargetUser: + if id := m.target_user; id != nil { + return []ent.Value{*id} + } + case pendingauthsession.EdgeAdoptionDecision: + if id := m.adoption_decision; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedtarget_user { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.clearedadoption_decision { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { +func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool { + switch name { + case pendingauthsession.EdgeTargetUser: + return m.clearedtarget_user + case pendingauthsession.EdgeAdoptionDecision: + return m.clearedadoption_decision + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) +func (m *PendingAuthSessionMutation) ClearEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ClearTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ClearAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) +func (m *PendingAuthSessionMutation) ResetEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ResetTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ResetAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession edge %s", name) } // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. @@ -28264,6 +32671,9 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + signup_source *string + last_login_at *time.Time + last_active_at *time.Time balance_notify_enabled *bool balance_notify_threshold_type *string balance_notify_threshold *float64 @@ -28302,6 +32712,12 @@ type UserMutation struct { payment_orders map[int64]struct{} removedpayment_orders map[int64]struct{} clearedpayment_orders bool + auth_identities map[int64]struct{} + removedauth_identities map[int64]struct{} + clearedauth_identities bool + pending_auth_sessions map[int64]struct{} + removedpending_auth_sessions map[int64]struct{} + clearedpending_auth_sessions bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -28988,6 +33404,140 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSignupSource sets the "signup_source" field. +func (m *UserMutation) SetSignupSource(s string) { + m.signup_source = &s +} + +// SignupSource returns the value of the "signup_source" field in the mutation. +func (m *UserMutation) SignupSource() (r string, exists bool) { + v := m.signup_source + if v == nil { + return + } + return *v, true +} + +// OldSignupSource returns the old "signup_source" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignupSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignupSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignupSource: %w", err) + } + return oldValue.SignupSource, nil +} + +// ResetSignupSource resets all changes to the "signup_source" field. +func (m *UserMutation) ResetSignupSource() { + m.signup_source = nil +} + +// SetLastLoginAt sets the "last_login_at" field. +func (m *UserMutation) SetLastLoginAt(t time.Time) { + m.last_login_at = &t +} + +// LastLoginAt returns the value of the "last_login_at" field in the mutation. +func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) { + v := m.last_login_at + if v == nil { + return + } + return *v, true +} + +// OldLastLoginAt returns the old "last_login_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLoginAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err) + } + return oldValue.LastLoginAt, nil +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (m *UserMutation) ClearLastLoginAt() { + m.last_login_at = nil + m.clearedFields[user.FieldLastLoginAt] = struct{}{} +} + +// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation. +func (m *UserMutation) LastLoginAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastLoginAt] + return ok +} + +// ResetLastLoginAt resets all changes to the "last_login_at" field. +func (m *UserMutation) ResetLastLoginAt() { + m.last_login_at = nil + delete(m.clearedFields, user.FieldLastLoginAt) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (m *UserMutation) SetLastActiveAt(t time.Time) { + m.last_active_at = &t +} + +// LastActiveAt returns the value of the "last_active_at" field in the mutation. +func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) { + v := m.last_active_at + if v == nil { + return + } + return *v, true +} + +// OldLastActiveAt returns the old "last_active_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActiveAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err) + } + return oldValue.LastActiveAt, nil +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (m *UserMutation) ClearLastActiveAt() { + m.last_active_at = nil + m.clearedFields[user.FieldLastActiveAt] = struct{}{} +} + +// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation. +func (m *UserMutation) LastActiveAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastActiveAt] + return ok +} + +// ResetLastActiveAt resets all changes to the "last_active_at" field. +func (m *UserMutation) ResetLastActiveAt() { + m.last_active_at = nil + delete(m.clearedFields, user.FieldLastActiveAt) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { m.balance_notify_enabled = &b @@ -29762,6 +34312,114 @@ func (m *UserMutation) ResetPaymentOrders() { m.removedpayment_orders = nil } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids. +func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) { + if m.auth_identities == nil { + m.auth_identities = make(map[int64]struct{}) + } + for i := range ids { + m.auth_identities[ids[i]] = struct{}{} + } +} + +// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) ClearAuthIdentities() { + m.clearedauth_identities = true +} + +// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared. +func (m *UserMutation) AuthIdentitiesCleared() bool { + return m.clearedauth_identities +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs. +func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) { + if m.removedauth_identities == nil { + m.removedauth_identities = make(map[int64]struct{}) + } + for i := range ids { + delete(m.auth_identities, ids[i]) + m.removedauth_identities[ids[i]] = struct{}{} + } +} + +// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) { + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return +} + +// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation. +func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) { + for id := range m.auth_identities { + ids = append(ids, id) + } + return +} + +// ResetAuthIdentities resets all changes to the "auth_identities" edge. +func (m *UserMutation) ResetAuthIdentities() { + m.auth_identities = nil + m.clearedauth_identities = false + m.removedauth_identities = nil +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids. +func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) { + if m.pending_auth_sessions == nil { + m.pending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + m.pending_auth_sessions[ids[i]] = struct{}{} + } +} + +// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) ClearPendingAuthSessions() { + m.clearedpending_auth_sessions = true +} + +// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared. +func (m *UserMutation) PendingAuthSessionsCleared() bool { + return m.clearedpending_auth_sessions +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) { + if m.removedpending_auth_sessions == nil { + m.removedpending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.pending_auth_sessions, ids[i]) + m.removedpending_auth_sessions[ids[i]] = struct{}{} + } +} + +// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) { + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return +} + +// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation. +func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) { + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return +} + +// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge. +func (m *UserMutation) ResetPendingAuthSessions() { + m.pending_auth_sessions = nil + m.clearedpending_auth_sessions = false + m.removedpending_auth_sessions = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -29796,7 +34454,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 22) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29839,6 +34497,15 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.signup_source != nil { + fields = append(fields, user.FieldSignupSource) + } + if m.last_login_at != nil { + fields = append(fields, user.FieldLastLoginAt) + } + if m.last_active_at != nil { + fields = append(fields, user.FieldLastActiveAt) + } if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } @@ -29890,6 +34557,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSignupSource: + return m.SignupSource() + case user.FieldLastLoginAt: + return m.LastLoginAt() + case user.FieldLastActiveAt: + return m.LastActiveAt() case user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() case user.FieldBalanceNotifyThresholdType: @@ -29937,6 +34610,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSignupSource: + return m.OldSignupSource(ctx) + case user.FieldLastLoginAt: + return m.OldLastLoginAt(ctx) + case user.FieldLastActiveAt: + return m.OldLastActiveAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) case user.FieldBalanceNotifyThresholdType: @@ -30054,6 +34733,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSignupSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignupSource(v) + return nil + case user.FieldLastLoginAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLoginAt(v) + return nil + case user.FieldLastActiveAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActiveAt(v) + return nil case user.FieldBalanceNotifyEnabled: v, ok := value.(bool) if !ok { @@ -30179,6 +34879,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldLastLoginAt) { + fields = append(fields, user.FieldLastLoginAt) + } + if m.FieldCleared(user.FieldLastActiveAt) { + fields = append(fields, user.FieldLastActiveAt) + } if m.FieldCleared(user.FieldBalanceNotifyThreshold) { fields = append(fields, user.FieldBalanceNotifyThreshold) } @@ -30205,6 +34911,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldLastLoginAt: + m.ClearLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ClearLastActiveAt() + return nil case user.FieldBalanceNotifyThreshold: m.ClearBalanceNotifyThreshold() return nil @@ -30258,6 +34970,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSignupSource: + m.ResetSignupSource() + return nil + case user.FieldLastLoginAt: + m.ResetLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ResetLastActiveAt() + return nil case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil @@ -30279,7 +35000,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30310,6 +35031,12 @@ func (m *UserMutation) AddedEdges() []string { if m.payment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.auth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.pending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30377,13 +35104,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.auth_identities)) + for id := range m.auth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.pending_auth_sessions)) + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30414,6 +35153,12 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpayment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.removedauth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.removedpending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30481,13 +35226,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.removedauth_identities)) + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions)) + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -30518,6 +35275,12 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpayment_orders { edges = append(edges, user.EdgePaymentOrders) } + if m.clearedauth_identities { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.clearedpending_auth_sessions { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30545,6 +35308,10 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedpromo_code_usages case user.EdgePaymentOrders: return m.clearedpayment_orders + case user.EdgeAuthIdentities: + return m.clearedauth_identities + case user.EdgePendingAuthSessions: + return m.clearedpending_auth_sessions } return false } @@ -30591,6 +35358,12 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePaymentOrders: m.ResetPaymentOrders() return nil + case user.EdgeAuthIdentities: + m.ResetAuthIdentities() + return nil + case user.EdgePendingAuthSessions: + m.ResetPendingAuthSessions() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go index 6ea3e70981d1884751b7512e541f53c057c1c206..b131b8c8804575eb74196359508d52db2122391f 100644 --- a/backend/ent/paymentorder.go +++ b/backend/ent/paymentorder.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -56,6 +57,10 @@ type PaymentOrder struct { SubscriptionDays *int `json:"subscription_days,omitempty"` // ProviderInstanceID holds the value of the "provider_instance_id" field. ProviderInstanceID *string `json:"provider_instance_id,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey *string `json:"provider_key,omitempty"` + // ProviderSnapshot holds the value of the "provider_snapshot" field. + ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` // RefundAmount holds the value of the "refund_amount" field. @@ -123,13 +128,15 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case paymentorder.FieldProviderSnapshot: + values[i] = new([]byte) case paymentorder.FieldForceRefund: values[i] = new(sql.NullBool) case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount: values[i] = new(sql.NullFloat64) case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays: values[i] = new(sql.NullInt64) - case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: + case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: values[i] = new(sql.NullString) case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -276,6 +283,21 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error { _m.ProviderInstanceID = new(string) *_m.ProviderInstanceID = value.String } + case paymentorder.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = new(string) + *_m.ProviderKey = value.String + } + case paymentorder.FieldProviderSnapshot: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil { + return fmt.Errorf("unmarshal field provider_snapshot: %w", err) + } + } case paymentorder.FieldStatus: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) @@ -508,6 +530,14 @@ func (_m *PaymentOrder) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ProviderKey; v != nil { + builder.WriteString("provider_key=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("provider_snapshot=") + builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot)) + builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go index 4467b2b635896402c3254245eff2fec0d8fb4136..6288379434280fe4c3cee2a8294ebdf98f686ed2 100644 --- a/backend/ent/paymentorder/paymentorder.go +++ b/backend/ent/paymentorder/paymentorder.go @@ -52,6 +52,10 @@ const ( FieldSubscriptionDays = "subscription_days" // FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database. FieldProviderInstanceID = "provider_instance_id" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database. + FieldProviderSnapshot = "provider_snapshot" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" // FieldRefundAmount holds the string denoting the refund_amount field in the database. @@ -123,6 +127,8 @@ var Columns = []string{ FieldSubscriptionGroupID, FieldSubscriptionDays, FieldProviderInstanceID, + FieldProviderKey, + FieldProviderSnapshot, FieldStatus, FieldRefundAmount, FieldRefundReason, @@ -176,6 +182,8 @@ var ( OrderTypeValidator func(string) error // ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. ProviderInstanceIDValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error // DefaultStatus holds the default value on creation for the "status" field. DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. @@ -301,6 +309,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc() } +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + // ByStatus orders the results by the status field. func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go index 78520fac4286fd7281262789aa3154697f1e3951..e96bf51ebd09499c5f478fa88720e669f3b02894 100644 --- a/backend/ent/paymentorder/where.go +++ b/backend/ent/paymentorder/where.go @@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v)) } +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. func Status(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) @@ -1360,6 +1365,91 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v)) } +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field. +func ProviderKeyIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey)) +} + +// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field. +func ProviderKeyNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field. +func ProviderSnapshotIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot)) +} + +// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field. +func ProviderSnapshotNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot)) +} + // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go index 030983390124b2840304ae5f021b6634ed3abdba..3ee24f8e918dbd18a1f3aad0c805bf870af6eebd 100644 --- a/backend/ent/paymentorder_create.go +++ b/backend/ent/paymentorder_create.go @@ -225,6 +225,26 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO return _c } +// SetProviderKey sets the "provider_key" field. +func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate { + if v != nil { + _c.SetProviderKey(*v) + } + return _c +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate { + _c.mutation.SetProviderSnapshot(v) + return _c +} + // SetStatus sets the "status" field. func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate { _c.mutation.SetStatus(v) @@ -602,6 +622,11 @@ func (_c *PaymentOrderCreate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if _, ok := _c.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)} } @@ -748,6 +773,14 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec) _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value) _node.ProviderInstanceID = &value } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = &value + } + if value, ok := _c.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + _node.ProviderSnapshot = value + } if value, ok := _c.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) _node.Status = value @@ -1201,6 +1234,42 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert { return u } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderKey) + return u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderKey) + return u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderSnapshot, v) + return u +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderSnapshot) + return u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderSnapshot) + return u +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert { u.Set(paymentorder.FieldStatus, v) @@ -1880,6 +1949,48 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderSnapshot(v) + }) +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderSnapshot() + }) +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderSnapshot() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne { return u.Update(func(s *PaymentOrderUpsert) { @@ -2770,6 +2881,48 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderSnapshot(v) + }) +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderSnapshot() + }) +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderSnapshot() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk { return u.Update(func(s *PaymentOrderUpsert) { diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go index 5978fc29148618f828e6586f8baf59d24c64ed1f..378e0dad2f90f2233387fd8bb02bf7018a09ed69 100644 --- a/backend/ent/paymentorder_update.go +++ b/backend/ent/paymentorder_update.go @@ -385,6 +385,38 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate { return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate { + _u.mutation.ClearProviderKey() + return _u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate { + _u.mutation.SetProviderSnapshot(v) + return _u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate { + _u.mutation.ClearProviderSnapshot() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate { _u.mutation.SetStatus(v) @@ -776,6 +808,11 @@ func (_u *PaymentOrderUpdate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -910,6 +947,18 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } + if value, ok := _u.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + } + if _u.mutation.ProviderSnapshotCleared() { + _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } @@ -1399,6 +1448,38 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderKey() + return _u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne { + _u.mutation.SetProviderSnapshot(v) + return _u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderSnapshot() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne { _u.mutation.SetStatus(v) @@ -1803,6 +1884,11 @@ func (_u *PaymentOrderUpdateOne) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -1954,6 +2040,18 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } + if value, ok := _u.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + } + if _u.mutation.ProviderSnapshotCleared() { + _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go new file mode 100644 index 0000000000000000000000000000000000000000..e77c065f779add6dc6dd6cbf860bfda6dfe418ba --- /dev/null +++ b/backend/ent/pendingauthsession.go @@ -0,0 +1,399 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSession is the model entity for the PendingAuthSession schema. +type PendingAuthSession struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken string `json:"session_token,omitempty"` + // Intent holds the value of the "intent" field. + Intent string `json:"intent,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // TargetUserID holds the value of the "target_user_id" field. + TargetUserID *int64 `json:"target_user_id,omitempty"` + // RedirectTo holds the value of the "redirect_to" field. + RedirectTo string `json:"redirect_to,omitempty"` + // ResolvedEmail holds the value of the "resolved_email" field. + ResolvedEmail string `json:"resolved_email,omitempty"` + // RegistrationPasswordHash holds the value of the "registration_password_hash" field. + RegistrationPasswordHash string `json:"registration_password_hash,omitempty"` + // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field. + UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"` + // LocalFlowState holds the value of the "local_flow_state" field. + LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"` + // BrowserSessionKey holds the value of the "browser_session_key" field. + BrowserSessionKey string `json:"browser_session_key,omitempty"` + // CompletionCodeHash holds the value of the "completion_code_hash" field. + CompletionCodeHash string `json:"completion_code_hash,omitempty"` + // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field. + CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"` + // EmailVerifiedAt holds the value of the "email_verified_at" field. + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + // PasswordVerifiedAt holds the value of the "password_verified_at" field. + PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"` + // TotpVerifiedAt holds the value of the "totp_verified_at" field. + TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // ConsumedAt holds the value of the "consumed_at" field. + ConsumedAt *time.Time `json:"consumed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PendingAuthSessionQuery when eager-loading is set. + Edges PendingAuthSessionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph. +type PendingAuthSessionEdges struct { + // TargetUser holds the value of the target_user edge. + TargetUser *User `json:"target_user,omitempty"` + // AdoptionDecision holds the value of the adoption_decision edge. + AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// TargetUserOrErr returns the TargetUser value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) { + if e.TargetUser != nil { + return e.TargetUser, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "target_user"} +} + +// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) { + if e.AdoptionDecision != nil { + return e.AdoptionDecision, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: identityadoptiondecision.Label} + } + return nil, &NotLoadedError{edge: "adoption_decision"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*PendingAuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState: + values[i] = new([]byte) + case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID: + values[i] = new(sql.NullInt64) + case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash: + values[i] = new(sql.NullString) + case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the PendingAuthSession fields. +func (_m *PendingAuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case pendingauthsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case pendingauthsession.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case pendingauthsession.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = value.String + } + case pendingauthsession.FieldIntent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field intent", values[i]) + } else if value.Valid { + _m.Intent = value.String + } + case pendingauthsession.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case pendingauthsession.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case pendingauthsession.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case pendingauthsession.FieldTargetUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field target_user_id", values[i]) + } else if value.Valid { + _m.TargetUserID = new(int64) + *_m.TargetUserID = value.Int64 + } + case pendingauthsession.FieldRedirectTo: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field redirect_to", values[i]) + } else if value.Valid { + _m.RedirectTo = value.String + } + case pendingauthsession.FieldResolvedEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resolved_email", values[i]) + } else if value.Valid { + _m.ResolvedEmail = value.String + } + case pendingauthsession.FieldRegistrationPasswordHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i]) + } else if value.Valid { + _m.RegistrationPasswordHash = value.String + } + case pendingauthsession.FieldUpstreamIdentityClaims: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil { + return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err) + } + } + case pendingauthsession.FieldLocalFlowState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field local_flow_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil { + return fmt.Errorf("unmarshal field local_flow_state: %w", err) + } + } + case pendingauthsession.FieldBrowserSessionKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field browser_session_key", values[i]) + } else if value.Valid { + _m.BrowserSessionKey = value.String + } + case pendingauthsession.FieldCompletionCodeHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i]) + } else if value.Valid { + _m.CompletionCodeHash = value.String + } + case pendingauthsession.FieldCompletionCodeExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i]) + } else if value.Valid { + _m.CompletionCodeExpiresAt = new(time.Time) + *_m.CompletionCodeExpiresAt = value.Time + } + case pendingauthsession.FieldEmailVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field email_verified_at", values[i]) + } else if value.Valid { + _m.EmailVerifiedAt = new(time.Time) + *_m.EmailVerifiedAt = value.Time + } + case pendingauthsession.FieldPasswordVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field password_verified_at", values[i]) + } else if value.Valid { + _m.PasswordVerifiedAt = new(time.Time) + *_m.PasswordVerifiedAt = value.Time + } + case pendingauthsession.FieldTotpVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i]) + } else if value.Valid { + _m.TotpVerifiedAt = new(time.Time) + *_m.TotpVerifiedAt = value.Time + } + case pendingauthsession.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + case pendingauthsession.FieldConsumedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field consumed_at", values[i]) + } else if value.Valid { + _m.ConsumedAt = new(time.Time) + *_m.ConsumedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *PendingAuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryTargetUser() *UserQuery { + return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m) +} + +// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m) +} + +// Update returns a builder for updating this PendingAuthSession. +// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne { + return NewPendingAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *PendingAuthSession) Unwrap() *PendingAuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: PendingAuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *PendingAuthSession) String() string { + var builder strings.Builder + builder.WriteString("PendingAuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("session_token=") + builder.WriteString(_m.SessionToken) + builder.WriteString(", ") + builder.WriteString("intent=") + builder.WriteString(_m.Intent) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.TargetUserID; v != nil { + builder.WriteString("target_user_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("redirect_to=") + builder.WriteString(_m.RedirectTo) + builder.WriteString(", ") + builder.WriteString("resolved_email=") + builder.WriteString(_m.ResolvedEmail) + builder.WriteString(", ") + builder.WriteString("registration_password_hash=") + builder.WriteString(_m.RegistrationPasswordHash) + builder.WriteString(", ") + builder.WriteString("upstream_identity_claims=") + builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims)) + builder.WriteString(", ") + builder.WriteString("local_flow_state=") + builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState)) + builder.WriteString(", ") + builder.WriteString("browser_session_key=") + builder.WriteString(_m.BrowserSessionKey) + builder.WriteString(", ") + builder.WriteString("completion_code_hash=") + builder.WriteString(_m.CompletionCodeHash) + builder.WriteString(", ") + if v := _m.CompletionCodeExpiresAt; v != nil { + builder.WriteString("completion_code_expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.EmailVerifiedAt; v != nil { + builder.WriteString("email_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.PasswordVerifiedAt; v != nil { + builder.WriteString("password_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TotpVerifiedAt; v != nil { + builder.WriteString("totp_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConsumedAt; v != nil { + builder.WriteString("consumed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// PendingAuthSessions is a parsable slice of PendingAuthSession. +type PendingAuthSessions []*PendingAuthSession diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go new file mode 100644 index 0000000000000000000000000000000000000000..8a3ac9bf783f191c796ce78f71c6d89130ae3c1c --- /dev/null +++ b/backend/ent/pendingauthsession/pendingauthsession.go @@ -0,0 +1,279 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the pendingauthsession type in the database. + Label = "pending_auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldIntent holds the string denoting the intent field in the database. + FieldIntent = "intent" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldTargetUserID holds the string denoting the target_user_id field in the database. + FieldTargetUserID = "target_user_id" + // FieldRedirectTo holds the string denoting the redirect_to field in the database. + FieldRedirectTo = "redirect_to" + // FieldResolvedEmail holds the string denoting the resolved_email field in the database. + FieldResolvedEmail = "resolved_email" + // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database. + FieldRegistrationPasswordHash = "registration_password_hash" + // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database. + FieldUpstreamIdentityClaims = "upstream_identity_claims" + // FieldLocalFlowState holds the string denoting the local_flow_state field in the database. + FieldLocalFlowState = "local_flow_state" + // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database. + FieldBrowserSessionKey = "browser_session_key" + // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database. + FieldCompletionCodeHash = "completion_code_hash" + // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database. + FieldCompletionCodeExpiresAt = "completion_code_expires_at" + // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database. + FieldEmailVerifiedAt = "email_verified_at" + // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database. + FieldPasswordVerifiedAt = "password_verified_at" + // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database. + FieldTotpVerifiedAt = "totp_verified_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldConsumedAt holds the string denoting the consumed_at field in the database. + FieldConsumedAt = "consumed_at" + // EdgeTargetUser holds the string denoting the target_user edge name in mutations. + EdgeTargetUser = "target_user" + // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations. + EdgeAdoptionDecision = "adoption_decision" + // Table holds the table name of the pendingauthsession in the database. + Table = "pending_auth_sessions" + // TargetUserTable is the table that holds the target_user relation/edge. + TargetUserTable = "pending_auth_sessions" + // TargetUserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + TargetUserInverseTable = "users" + // TargetUserColumn is the table column denoting the target_user relation/edge. + TargetUserColumn = "target_user_id" + // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge. + AdoptionDecisionTable = "identity_adoption_decisions" + // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionInverseTable = "identity_adoption_decisions" + // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge. + AdoptionDecisionColumn = "pending_auth_session_id" +) + +// Columns holds all SQL columns for pendingauthsession fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldSessionToken, + FieldIntent, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldTargetUserID, + FieldRedirectTo, + FieldResolvedEmail, + FieldRegistrationPasswordHash, + FieldUpstreamIdentityClaims, + FieldLocalFlowState, + FieldBrowserSessionKey, + FieldCompletionCodeHash, + FieldCompletionCodeExpiresAt, + FieldEmailVerifiedAt, + FieldPasswordVerifiedAt, + FieldTotpVerifiedAt, + FieldExpiresAt, + FieldConsumedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + SessionTokenValidator func(string) error + // IntentValidator is a validator for the "intent" field. It is called by the builders before save. + IntentValidator func(string) error + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultRedirectTo holds the default value on creation for the "redirect_to" field. + DefaultRedirectTo string + // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field. + DefaultResolvedEmail string + // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field. + DefaultRegistrationPasswordHash string + // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field. + DefaultUpstreamIdentityClaims func() map[string]interface{} + // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field. + DefaultLocalFlowState func() map[string]interface{} + // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field. + DefaultBrowserSessionKey string + // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field. + DefaultCompletionCodeHash string +) + +// OrderOption defines the ordering options for the PendingAuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. +func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByIntent orders the results by the intent field. +func ByIntent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntent, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByTargetUserID orders the results by the target_user_id field. +func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTargetUserID, opts...).ToFunc() +} + +// ByRedirectTo orders the results by the redirect_to field. +func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRedirectTo, opts...).ToFunc() +} + +// ByResolvedEmail orders the results by the resolved_email field. +func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc() +} + +// ByRegistrationPasswordHash orders the results by the registration_password_hash field. +func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc() +} + +// ByBrowserSessionKey orders the results by the browser_session_key field. +func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc() +} + +// ByCompletionCodeHash orders the results by the completion_code_hash field. +func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc() +} + +// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field. +func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc() +} + +// ByEmailVerifiedAt orders the results by the email_verified_at field. +func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc() +} + +// ByPasswordVerifiedAt orders the results by the password_verified_at field. +func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc() +} + +// ByTotpVerifiedAt orders the results by the totp_verified_at field. +func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByConsumedAt orders the results by the consumed_at field. +func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsumedAt, opts...).ToFunc() +} + +// ByTargetUserField orders the results by target_user field. +func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAdoptionDecisionField orders the results by adoption_decision field. +func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...)) + } +} +func newTargetUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TargetUserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) +} +func newAdoptionDecisionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) +} diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go new file mode 100644 index 0000000000000000000000000000000000000000..cb316f476e44195e74f961699d058837cbe38630 --- /dev/null +++ b/backend/ent/pendingauthsession/where.go @@ -0,0 +1,1262 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ. +func Intent(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ. +func TargetUserID(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ. +func RedirectTo(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ. +func ResolvedEmail(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ. +func RegistrationPasswordHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ. +func BrowserSessionKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ. +func CompletionCodeHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ. +func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ. +func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ. +func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ. +func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ. +func ConsumedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. +func SessionTokenNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. +func SessionTokenLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. +func SessionTokenHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// IntentEQ applies the EQ predicate on the "intent" field. +func IntentEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// IntentNEQ applies the NEQ predicate on the "intent" field. +func IntentNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v)) +} + +// IntentIn applies the In predicate on the "intent" field. +func IntentIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...)) +} + +// IntentNotIn applies the NotIn predicate on the "intent" field. +func IntentNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...)) +} + +// IntentGT applies the GT predicate on the "intent" field. +func IntentGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v)) +} + +// IntentGTE applies the GTE predicate on the "intent" field. +func IntentGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v)) +} + +// IntentLT applies the LT predicate on the "intent" field. +func IntentLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v)) +} + +// IntentLTE applies the LTE predicate on the "intent" field. +func IntentLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v)) +} + +// IntentContains applies the Contains predicate on the "intent" field. +func IntentContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v)) +} + +// IntentHasPrefix applies the HasPrefix predicate on the "intent" field. +func IntentHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v)) +} + +// IntentHasSuffix applies the HasSuffix predicate on the "intent" field. +func IntentHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v)) +} + +// IntentEqualFold applies the EqualFold predicate on the "intent" field. +func IntentEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v)) +} + +// IntentContainsFold applies the ContainsFold predicate on the "intent" field. +func IntentContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field. +func TargetUserIDEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field. +func TargetUserIDNEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v)) +} + +// TargetUserIDIn applies the In predicate on the "target_user_id" field. +func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field. +func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field. +func TargetUserIDIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID)) +} + +// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field. +func TargetUserIDNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID)) +} + +// RedirectToEQ applies the EQ predicate on the "redirect_to" field. +func RedirectToEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field. +func RedirectToNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v)) +} + +// RedirectToIn applies the In predicate on the "redirect_to" field. +func RedirectToIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...)) +} + +// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field. +func RedirectToNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...)) +} + +// RedirectToGT applies the GT predicate on the "redirect_to" field. +func RedirectToGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v)) +} + +// RedirectToGTE applies the GTE predicate on the "redirect_to" field. +func RedirectToGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v)) +} + +// RedirectToLT applies the LT predicate on the "redirect_to" field. +func RedirectToLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v)) +} + +// RedirectToLTE applies the LTE predicate on the "redirect_to" field. +func RedirectToLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v)) +} + +// RedirectToContains applies the Contains predicate on the "redirect_to" field. +func RedirectToContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v)) +} + +// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field. +func RedirectToHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v)) +} + +// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field. +func RedirectToHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v)) +} + +// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field. +func RedirectToEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v)) +} + +// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field. +func RedirectToContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v)) +} + +// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field. +func ResolvedEmailEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field. +func ResolvedEmailNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailIn applies the In predicate on the "resolved_email" field. +func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field. +func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailGT applies the GT predicate on the "resolved_email" field. +func ResolvedEmailGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v)) +} + +// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field. +func ResolvedEmailGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailLT applies the LT predicate on the "resolved_email" field. +func ResolvedEmailLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v)) +} + +// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field. +func ResolvedEmailLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field. +func ResolvedEmailContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field. +func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field. +func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v)) +} + +// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field. +func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v)) +} + +// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field. +func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field. +func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field. +func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field. +func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field. +func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field. +func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field. +func BrowserSessionKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field. +func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field. +func BrowserSessionKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field. +func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field. +func BrowserSessionKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field. +func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field. +func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field. +func CompletionCodeHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field. +func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field. +func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field. +func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field. +func CompletionCodeHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field. +func CompletionCodeHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field. +func CompletionCodeHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field. +func CompletionCodeHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field. +func CompletionCodeHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field. +func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field. +func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt)) +} + +// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt)) +} + +// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field. +func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field. +func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field. +func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field. +func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field. +func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field. +func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field. +func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field. +func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field. +func EmailVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt)) +} + +// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field. +func EmailVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt)) +} + +// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field. +func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field. +func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field. +func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt)) +} + +// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt)) +} + +// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field. +func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field. +func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field. +func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt)) +} + +// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field. +func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field. +func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v)) +} + +// ConsumedAtIn applies the In predicate on the "consumed_at" field. +func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field. +func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtGT applies the GT predicate on the "consumed_at" field. +func ConsumedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v)) +} + +// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field. +func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v)) +} + +// ConsumedAtLT applies the LT predicate on the "consumed_at" field. +func ConsumedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v)) +} + +// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field. +func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v)) +} + +// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field. +func ConsumedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt)) +} + +// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field. +func ConsumedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt)) +} + +// HasTargetUser applies the HasEdge predicate on the "target_user" edge. +func HasTargetUser() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates). +func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newTargetUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge. +func HasAdoptionDecision() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates). +func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newAdoptionDecisionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.NotPredicates(p)) +} diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go new file mode 100644 index 0000000000000000000000000000000000000000..60276daa1bd9a1913f8fb65b0ff515471dc48210 --- /dev/null +++ b/backend/ent/pendingauthsession_create.go @@ -0,0 +1,1815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity. +type PendingAuthSessionCreate struct { + config + mutation *PendingAuthSessionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetIntent sets the "intent" field. +func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate { + _c.mutation.SetIntent(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetTargetUserID sets the "target_user_id" field. +func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate { + _c.mutation.SetTargetUserID(v) + return _c +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate { + if v != nil { + _c.SetTargetUserID(*v) + } + return _c +} + +// SetRedirectTo sets the "redirect_to" field. +func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate { + _c.mutation.SetRedirectTo(v) + return _c +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRedirectTo(*v) + } + return _c +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate { + _c.mutation.SetResolvedEmail(v) + return _c +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetResolvedEmail(*v) + } + return _c +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetRegistrationPasswordHash(v) + return _c +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRegistrationPasswordHash(*v) + } + return _c +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetUpstreamIdentityClaims(v) + return _c +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetLocalFlowState(v) + return _c +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetBrowserSessionKey(v) + return _c +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetBrowserSessionKey(*v) + } + return _c +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeHash(v) + return _c +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeHash(*v) + } + return _c +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeExpiresAt(v) + return _c +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeExpiresAt(*v) + } + return _c +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetEmailVerifiedAt(v) + return _c +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetEmailVerifiedAt(*v) + } + return _c +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetPasswordVerifiedAt(v) + return _c +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetPasswordVerifiedAt(*v) + } + return _c +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetTotpVerifiedAt(v) + return _c +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetTotpVerifiedAt(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetConsumedAt sets the "consumed_at" field. +func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetConsumedAt(v) + return _c +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetConsumedAt(*v) + } + return _c +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate { + return _c.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate { + _c.mutation.SetAdoptionDecisionID(id) + return _c +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate { + if id != nil { + _c = _c.SetAdoptionDecisionID(*id) + } + return _c +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate { + return _c.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation { + return _c.mutation +} + +// Save creates the PendingAuthSession in the database. +func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *PendingAuthSessionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := pendingauthsession.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := pendingauthsession.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.RedirectTo(); !ok { + v := pendingauthsession.DefaultRedirectTo + _c.mutation.SetRedirectTo(v) + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + v := pendingauthsession.DefaultResolvedEmail + _c.mutation.SetResolvedEmail(v) + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + v := pendingauthsession.DefaultRegistrationPasswordHash + _c.mutation.SetRegistrationPasswordHash(v) + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + v := pendingauthsession.DefaultUpstreamIdentityClaims() + _c.mutation.SetUpstreamIdentityClaims(v) + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + v := pendingauthsession.DefaultLocalFlowState() + _c.mutation.SetLocalFlowState(v) + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + v := pendingauthsession.DefaultBrowserSessionKey + _c.mutation.SetBrowserSessionKey(v) + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + v := pendingauthsession.DefaultCompletionCodeHash + _c.mutation.SetCompletionCodeHash(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *PendingAuthSessionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)} + } + if _, ok := _c.mutation.SessionToken(); !ok { + return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)} + } + if v, ok := _c.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if _, ok := _c.mutation.Intent(); !ok { + return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)} + } + if v, ok := _c.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.RedirectTo(); !ok { + return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)} + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)} + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)} + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)} + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)} + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)} + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)} + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)} + } + return nil +} + +func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &PendingAuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + _node.SessionToken = value + } + if value, ok := _c.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + _node.Intent = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + _node.RedirectTo = value + } + if value, ok := _c.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + _node.ResolvedEmail = value + } + if value, ok := _c.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + _node.RegistrationPasswordHash = value + } + if value, ok := _c.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + _node.UpstreamIdentityClaims = value + } + if value, ok := _c.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + _node.LocalFlowState = value + } + if value, ok := _c.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + _node.BrowserSessionKey = value + } + if value, ok := _c.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + _node.CompletionCodeHash = value + } + if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + _node.CompletionCodeExpiresAt = &value + } + if value, ok := _c.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + _node.EmailVerifiedAt = &value + } + if value, ok := _c.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + _node.PasswordVerifiedAt = &value + } + if value, ok := _c.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + _node.TotpVerifiedAt = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := _c.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + _node.ConsumedAt = &value + } + if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TargetUserID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne { + _c.conflict = opts + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +type ( + // PendingAuthSessionUpsertOne is the builder for "upsert"-ing + // one PendingAuthSession node. + PendingAuthSessionUpsertOne struct { + create *PendingAuthSessionCreate + } + + // PendingAuthSessionUpsert is the "OnConflict" setter. + PendingAuthSessionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpdatedAt) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldSessionToken) + return u +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldIntent, v) + return u +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldIntent) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderSubject) + return u +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTargetUserID, v) + return u +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTargetUserID) + return u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTargetUserID) + return u +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRedirectTo, v) + return u +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRedirectTo) + return u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldResolvedEmail, v) + return u +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldResolvedEmail) + return u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRegistrationPasswordHash, v) + return u +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash) + return u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v) + return u +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims) + return u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldLocalFlowState, v) + return u +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldLocalFlowState) + return u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldBrowserSessionKey, v) + return u +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldBrowserSessionKey) + return u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeHash, v) + return u +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeHash) + return u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v) + return u +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldEmailVerifiedAt, v) + return u +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldPasswordVerifiedAt, v) + return u +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTotpVerifiedAt, v) + return u +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldExpiresAt) + return u +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldConsumedAt, v) + return u +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldConsumedAt) + return u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldConsumedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk. +type PendingAuthSessionCreateBulk struct { + config + err error + builders []*PendingAuthSessionCreate + conflict []sql.ConflictOption +} + +// Save creates the PendingAuthSession entities in the database. +func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*PendingAuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PendingAuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk { + _c.conflict = opts + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing +// a bulk of PendingAuthSession nodes. +type PendingAuthSessionUpsertBulk struct { + create *PendingAuthSessionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..ee4fe6051d11812bdced2f78b4cc9554828910fa --- /dev/null +++ b/backend/ent/pendingauthsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity. +type PendingAuthSessionDelete struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity. +type PendingAuthSessionDeleteOne struct { + _d *PendingAuthSessionDelete +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{pendingauthsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go new file mode 100644 index 0000000000000000000000000000000000000000..78e29cd2bedf07258e955a0ca95420c5e3da9e3e --- /dev/null +++ b/backend/ent/pendingauthsession_query.go @@ -0,0 +1,717 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities. +type PendingAuthSessionQuery struct { + config + ctx *QueryContext + order []pendingauthsession.OrderOption + inters []Interceptor + predicates []predicate.PendingAuthSession + withTargetUser *UserQuery + withAdoptionDecision *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PendingAuthSessionQuery builder. +func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryTargetUser chains the current query on the "target_user" edge. +func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecision chains the current query on the "adoption_decision" edge. +func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first PendingAuthSession entity from the query. +// Returns a *NotFoundError when no PendingAuthSession was found. +func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{pendingauthsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first PendingAuthSession ID from the query. +// Returns a *NotFoundError when no PendingAuthSession ID was found. +func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{pendingauthsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one PendingAuthSession entity is found. +// Returns a *NotFoundError when no PendingAuthSession entities are found. +func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{pendingauthsession.Label} + default: + return nil, &NotSingularError{pendingauthsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only PendingAuthSession ID in the query. +// Returns a *NotSingularError when more than one PendingAuthSession ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{pendingauthsession.Label} + default: + err = &NotSingularError{pendingauthsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of PendingAuthSessions. +func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]() + return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of PendingAuthSession IDs. +func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery { + if _q == nil { + return nil + } + return &PendingAuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]pendingauthsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.PendingAuthSession{}, _q.predicates...), + withTargetUser: _q.withTargetUser.Clone(), + withAdoptionDecision: _q.withAdoptionDecision.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithTargetUser tells the query-builder to eager-load the nodes that are connected to +// the "target_user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withTargetUser = query + return _q +} + +// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecision = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// GroupBy(pendingauthsession.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &PendingAuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = pendingauthsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// Select(pendingauthsession.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q} + sbuild.label = pendingauthsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations. +func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !pendingauthsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) { + var ( + nodes = []*PendingAuthSession{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withTargetUser != nil, + _q.withAdoptionDecision != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*PendingAuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &PendingAuthSession{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withTargetUser; query != nil { + if err := _q.loadTargetUser(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecision; query != nil { + if err := _q.loadAdoptionDecision(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*PendingAuthSession) + for i := range nodes { + if nodes[i].TargetUserID == nil { + continue + } + fk := *nodes[i].TargetUserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*PendingAuthSession) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PendingAuthSessionID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for i := range fields { + if fields[i] != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withTargetUser != nil { + _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(pendingauthsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = pendingauthsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities. +type PendingAuthSessionGroupBy struct { + selector + build *PendingAuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities. +type PendingAuthSessionSelect struct { + *PendingAuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v) +} + +func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go new file mode 100644 index 0000000000000000000000000000000000000000..00066f699baf7a7fe87a01f4870035ba87c4431e --- /dev/null +++ b/backend/ent/pendingauthsession_update.go @@ -0,0 +1,1178 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities. +type PendingAuthSessionUpdate struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdate) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity. +type PendingAuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated PendingAuthSession entity. +func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdateOne) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for _, f := range fields { + if !pendingauthsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &PendingAuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940067ec1a2d96783ca78994bc6cb8d9fae..0aa90b90a8a4f7f64177710e3377ac498e7b8af7 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,12 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// AuthIdentity is the predicate function for authidentity builders. +type AuthIdentity func(*sql.Selector) + +// AuthIdentityChannel is the predicate function for authidentitychannel builders. +type AuthIdentityChannel func(*sql.Selector) + // ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. type ErrorPassthroughRule func(*sql.Selector) @@ -30,6 +36,9 @@ type Group func(*sql.Selector) // IdempotencyRecord is the predicate function for idempotencyrecord builders. type IdempotencyRecord func(*sql.Selector) +// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders. +type IdentityAdoptionDecision func(*sql.Selector) + // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) @@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector) // PaymentProviderInstance is the predicate function for paymentproviderinstance builders. type PaymentProviderInstance func(*sql.Selector) +// PendingAuthSession is the predicate function for pendingauthsession builders. +type PendingAuthSession func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fbdd08c785c5a3a81ebafe1d27ddf488f9bae17c..bdb7f7a93a0970b3355de73a274f90ce07e28358 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,12 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -309,6 +313,120 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + authidentityMixin := schema.AuthIdentity{}.Mixin() + authidentityMixinFields0 := authidentityMixin[0].Fields() + _ = authidentityMixinFields0 + authidentityFields := schema.AuthIdentity{}.Fields() + _ = authidentityFields + // authidentityDescCreatedAt is the schema descriptor for created_at field. + authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor() + // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time) + // authidentityDescUpdatedAt is the schema descriptor for updated_at field. + authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor() + // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time) + // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentityDescProviderType is the schema descriptor for provider_type field. + authidentityDescProviderType := authidentityFields[1].Descriptor() + // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentity.ProviderTypeValidator = func() func(string) error { + validators := authidentityDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentityDescProviderKey is the schema descriptor for provider_key field. + authidentityDescProviderKey := authidentityFields[2].Descriptor() + // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error) + // authidentityDescProviderSubject is the schema descriptor for provider_subject field. + authidentityDescProviderSubject := authidentityFields[3].Descriptor() + // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error) + // authidentityDescMetadata is the schema descriptor for metadata field. + authidentityDescMetadata := authidentityFields[6].Descriptor() + // authidentity.DefaultMetadata holds the default value on creation for the metadata field. + authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{}) + authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin() + authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields() + _ = authidentitychannelMixinFields0 + authidentitychannelFields := schema.AuthIdentityChannel{}.Fields() + _ = authidentitychannelFields + // authidentitychannelDescCreatedAt is the schema descriptor for created_at field. + authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor() + // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time) + // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field. + authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor() + // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time) + // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentitychannelDescProviderType is the schema descriptor for provider_type field. + authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor() + // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentitychannel.ProviderTypeValidator = func() func(string) error { + validators := authidentitychannelDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescProviderKey is the schema descriptor for provider_key field. + authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor() + // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error) + // authidentitychannelDescChannel is the schema descriptor for channel field. + authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor() + // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + authidentitychannel.ChannelValidator = func() func(string) error { + validators := authidentitychannelDescChannel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(channel string) error { + for _, fn := range fns { + if err := fn(channel); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field. + authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor() + // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error) + // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field. + authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor() + // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error) + // authidentitychannelDescMetadata is the schema descriptor for metadata field. + authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor() + // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field. + authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{}) errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() _ = errorpassthroughruleMixinFields0 @@ -512,6 +630,33 @@ func init() { idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) + identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin() + identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields() + _ = identityadoptiondecisionMixinFields0 + identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields() + _ = identityadoptiondecisionFields + // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field. + identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor() + // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field. + identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time) + // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field. + identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor() + // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field. + identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time) + // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time) + // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field. + identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor() + // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field. + identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool) + // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field. + identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor() + // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field. + identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool) + // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field. + identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor() + // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field. + identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time) paymentauditlogFields := schema.PaymentAuditLog{}.Fields() _ = paymentauditlogFields // paymentauditlogDescOrderID is the schema descriptor for order_id field. @@ -578,38 +723,42 @@ func init() { paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor() // paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error) + // paymentorderDescProviderKey is the schema descriptor for provider_key field. + paymentorderDescProviderKey := paymentorderFields[19].Descriptor() + // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error) // paymentorderDescStatus is the schema descriptor for status field. - paymentorderDescStatus := paymentorderFields[19].Descriptor() + paymentorderDescStatus := paymentorderFields[21].Descriptor() // paymentorder.DefaultStatus holds the default value on creation for the status field. paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string) // paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save. paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error) // paymentorderDescRefundAmount is the schema descriptor for refund_amount field. - paymentorderDescRefundAmount := paymentorderFields[20].Descriptor() + paymentorderDescRefundAmount := paymentorderFields[22].Descriptor() // paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field. paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64) // paymentorderDescForceRefund is the schema descriptor for force_refund field. - paymentorderDescForceRefund := paymentorderFields[23].Descriptor() + paymentorderDescForceRefund := paymentorderFields[25].Descriptor() // paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field. paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool) // paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field. - paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor() + paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor() // paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save. paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error) // paymentorderDescClientIP is the schema descriptor for client_ip field. - paymentorderDescClientIP := paymentorderFields[32].Descriptor() + paymentorderDescClientIP := paymentorderFields[34].Descriptor() // paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save. paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error) // paymentorderDescSrcHost is the schema descriptor for src_host field. - paymentorderDescSrcHost := paymentorderFields[33].Descriptor() + paymentorderDescSrcHost := paymentorderFields[35].Descriptor() // paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save. paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error) // paymentorderDescCreatedAt is the schema descriptor for created_at field. - paymentorderDescCreatedAt := paymentorderFields[35].Descriptor() + paymentorderDescCreatedAt := paymentorderFields[37].Descriptor() // paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field. paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time) // paymentorderDescUpdatedAt is the schema descriptor for updated_at field. - paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor() + paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor() // paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time) // paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -682,6 +831,113 @@ func init() { paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time) + pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin() + pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields() + _ = pendingauthsessionMixinFields0 + pendingauthsessionFields := schema.PendingAuthSession{}.Fields() + _ = pendingauthsessionFields + // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field. + pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor() + // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field. + pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time) + // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field. + pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor() + // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field. + pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time) + // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time) + // pendingauthsessionDescSessionToken is the schema descriptor for session_token field. + pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor() + // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + pendingauthsession.SessionTokenValidator = func() func(string) error { + validators := pendingauthsessionDescSessionToken.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(session_token string) error { + for _, fn := range fns { + if err := fn(session_token); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescIntent is the schema descriptor for intent field. + pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor() + // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save. + pendingauthsession.IntentValidator = func() func(string) error { + validators := pendingauthsessionDescIntent.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(intent string) error { + for _, fn := range fns { + if err := fn(intent); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderType is the schema descriptor for provider_type field. + pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor() + // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + pendingauthsession.ProviderTypeValidator = func() func(string) error { + validators := pendingauthsessionDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field. + pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor() + // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error) + // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field. + pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor() + // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error) + // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field. + pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor() + // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field. + pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string) + // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field. + pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor() + // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field. + pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string) + // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field. + pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor() + // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field. + pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string) + // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field. + pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor() + // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field. + pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{}) + // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field. + pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor() + // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field. + pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{}) + // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field. + pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor() + // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field. + pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string) + // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field. + pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor() + // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field. + pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -1297,20 +1553,26 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSignupSource is the schema descriptor for signup_source field. + userDescSignupSource := userFields[11].Descriptor() + // user.DefaultSignupSource holds the default value on creation for the signup_source field. + user.DefaultSignupSource = userDescSignupSource.Default.(string) + // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error) // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. - userDescBalanceNotifyEnabled := userFields[11].Descriptor() + userDescBalanceNotifyEnabled := userFields[14].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. - userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + userDescBalanceNotifyThresholdType := userFields[15].Descriptor() // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[17].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) // userDescTotalRecharged is the schema descriptor for total_recharged field. - userDescTotalRecharged := userFields[15].Descriptor() + userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go new file mode 100644 index 0000000000000000000000000000000000000000..e4b9ac909177e6db41b6a06e3e4658321d995fb2 --- /dev/null +++ b/backend/ent/schema/auth_identity.go @@ -0,0 +1,93 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var authProviderTypes = map[string]struct{}{ + "email": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, +} + +func validateAuthProviderType(value string) error { + if _, ok := authProviderTypes[value]; ok { + return nil + } + return fmt.Errorf("invalid auth provider type %q", value) +} + +// AuthIdentity stores the canonical login identity for an account. +type AuthIdentity struct { + ent.Schema +} + +func (AuthIdentity) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identities"}, + } +} + +func (AuthIdentity) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentity) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.String("issuer"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("auth_identities"). + Field("user_id"). + Required(). + Unique(), + edge.To("channels", AuthIdentityChannel.Type), + edge.To("adoption_decisions", IdentityAdoptionDecision.Type), + } +} + +func (AuthIdentity) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "provider_subject").Unique(), + index.Fields("user_id"), + index.Fields("user_id", "provider_type"), + } +} diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go new file mode 100644 index 0000000000000000000000000000000000000000..69f2ad028f3249331ec8e2a3ecdcfbab93e1acaf --- /dev/null +++ b/backend/ent/schema/auth_identity_channel.go @@ -0,0 +1,72 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity. +type AuthIdentityChannel struct { + ent.Schema +} + +func (AuthIdentityChannel) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identity_channels"}, + } +} + +func (AuthIdentityChannel) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentityChannel) Fields() []ent.Field { + return []ent.Field{ + field.Int64("identity_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel"). + MaxLen(20). + NotEmpty(), + field.String("channel_app_id"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentityChannel) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("identity", AuthIdentity.Type). + Ref("channels"). + Field("identity_id"). + Required(). + Unique(), + } +} + +func (AuthIdentityChannel) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go new file mode 100644 index 0000000000000000000000000000000000000000..de55dd6965483bf131fc1369de8b17233cd39d24 --- /dev/null +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -0,0 +1,124 @@ +package schema + +import ( + "testing" + + "entgo.io/ent/entc/load" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationSchemas(t *testing.T) { + spec, err := (&load.Config{Path: "."}).Load() + require.NoError(t, err) + + schemas := map[string]*load.Schema{} + for _, schema := range spec.Schemas { + schemas[schema.Name] = schema + } + + authIdentity := requireSchema(t, schemas, "AuthIdentity") + requireSchemaFields(t, authIdentity, + "user_id", + "provider_type", + "provider_key", + "provider_subject", + "verified_at", + "issuer", + "metadata", + ) + requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject") + + authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel") + requireSchemaFields(t, authIdentityChannel, + "identity_id", + "provider_type", + "provider_key", + "channel", + "channel_app_id", + "channel_subject", + "metadata", + ) + requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject") + + pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession") + requireSchemaFields(t, pendingAuthSession, + "intent", + "provider_type", + "provider_key", + "provider_subject", + "target_user_id", + "redirect_to", + "resolved_email", + "registration_password_hash", + "upstream_identity_claims", + "local_flow_state", + "browser_session_key", + "completion_code_hash", + "completion_code_expires_at", + "email_verified_at", + "password_verified_at", + "totp_verified_at", + "expires_at", + "consumed_at", + ) + + adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision") + requireSchemaFields(t, adoptionDecision, + "pending_auth_session_id", + "identity_id", + "adopt_display_name", + "adopt_avatar", + "decided_at", + ) + requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id") + + userSchema := requireSchema(t, schemas, "User") + requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") +} + +func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { + t.Helper() + + schema, ok := schemas[name] + require.True(t, ok, "schema %s should exist", name) + return schema +} + +func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { + t.Helper() + + fields := map[string]struct{}{} + for _, field := range schema.Fields { + fields[field.Name] = struct{}{} + } + + for _, name := range names { + _, ok := fields[name] + require.True(t, ok, "schema %s should include field %s", schema.Name, name) + } +} + +func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { + t.Helper() + + for _, index := range schema.Indexes { + if !index.Unique { + continue + } + if len(index.Fields) != len(fields) { + continue + } + match := true + for i := range fields { + if index.Fields[i] != fields[i] { + match = false + break + } + } + if match { + return + } + } + + require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields) +} diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go new file mode 100644 index 0000000000000000000000000000000000000000..9fdd26fbca5120d4ddc17caed5cdfe82cfcedf27 --- /dev/null +++ b/backend/ent/schema/identity_adoption_decision.go @@ -0,0 +1,70 @@ +package schema + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow. +type IdentityAdoptionDecision struct { + ent.Schema +} + +func (IdentityAdoptionDecision) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "identity_adoption_decisions"}, + } +} + +func (IdentityAdoptionDecision) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdentityAdoptionDecision) Fields() []ent.Field { + return []ent.Field{ + field.Int64("pending_auth_session_id"), + field.Int64("identity_id"). + Optional(). + Nillable(), + field.Bool("adopt_display_name"). + Default(false), + field.Bool("adopt_avatar"). + Default(false), + field.Time("decided_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (IdentityAdoptionDecision) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("pending_auth_session", PendingAuthSession.Type). + Ref("adoption_decision"). + Field("pending_auth_session_id"). + Required(). + Unique(), + edge.From("identity", AuthIdentity.Type). + Ref("adoption_decisions"). + Field("identity_id"). + Unique(), + } +} + +func (IdentityAdoptionDecision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("pending_auth_session_id").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go index a9576d2ab02aca442249de0457ea4321cf15f85e..5815d0327e27b04345200d9cbbef48229c0985f4 100644 --- a/backend/ent/schema/payment_order.go +++ b/backend/ent/schema/payment_order.go @@ -91,6 +91,13 @@ func (PaymentOrder) Fields() []ent.Field { Optional(). Nillable(). MaxLen(64), + field.String("provider_key"). + Optional(). + Nillable(). + MaxLen(30), + field.JSON("provider_snapshot", map[string]any{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), // 状态 field.String("status"). diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go new file mode 100644 index 0000000000000000000000000000000000000000..91341d49ff0835fe6a2884d64aadd0d402274929 --- /dev/null +++ b/backend/ent/schema/pending_auth_session.go @@ -0,0 +1,134 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var pendingAuthIntents = map[string]struct{}{ + "login": {}, + "bind_current_user": {}, + "adopt_existing_user_by_email": {}, +} + +func validatePendingAuthIntent(value string) error { + if _, ok := pendingAuthIntents[value]; ok { + return nil + } + return fmt.Errorf("invalid pending auth intent %q", value) +} + +// PendingAuthSession stores a short-lived post-auth decision session. +type PendingAuthSession struct { + ent.Schema +} + +func (PendingAuthSession) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "pending_auth_sessions"}, + } +} + +func (PendingAuthSession) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (PendingAuthSession) Fields() []ent.Field { + return []ent.Field{ + field.String("session_token"). + MaxLen(255). + NotEmpty(), + field.String("intent"). + MaxLen(40). + NotEmpty(). + Validate(validatePendingAuthIntent), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("target_user_id"). + Optional(). + Nillable(), + field.String("redirect_to"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("resolved_email"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("registration_password_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("upstream_identity_claims", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.JSON("local_flow_state", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.String("browser_session_key"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("completion_code_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("completion_code_expires_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("email_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("password_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("totp_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("consumed_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (PendingAuthSession) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("target_user", User.Type). + Ref("pending_auth_sessions"). + Field("target_user_id"). + Unique(), + edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Unique(), + } +} + +func (PendingAuthSession) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("session_token").Unique(), + index.Fields("target_user_id"), + index.Fields("expires_at"), + index.Fields("provider_type", "provider_key", "provider_subject"), + index.Fields("completion_code_hash"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ef52e985d91d84e568518cbecb1684251d82652c..bb58d9e33bec037b56523d0741e3497d296f7c46 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,17 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + field.String("signup_source"). + MaxLen(20). + Default("email"), + field.Time("last_login_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("last_active_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), // 余额不足通知 field.Bool("balance_notify_enabled"). @@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), + edge.To("auth_identities", AuthIdentity.Type), + edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index bb3139d5c84119195c82c071194ca4b8663fbab0..bde3e35b7bc5373248965a7395e24aa5d27ea844 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,18 +24,26 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -202,12 +210,16 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.AuthIdentity = NewAuthIdentityClient(tx.config) + tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) + tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) + tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 9fa91f74b9dda8b6bbeedc5f1f6fdf931df018d3..66f33623debd52b9b1e6f76e98a180a32fe54393 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,12 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SignupSource holds the value of the "signup_source" field. + SignupSource string `json:"signup_source,omitempty"` + // LastLoginAt holds the value of the "last_login_at" field. + LastLoginAt *time.Time `json:"last_login_at,omitempty"` + // LastActiveAt holds the value of the "last_active_at" field. + LastActiveAt *time.Time `json:"last_active_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. @@ -83,11 +89,15 @@ type UserEdges struct { PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` // PaymentOrders holds the value of the payment_orders edge. PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` + // AuthIdentities holds the value of the auth_identities edge. + AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` + // PendingAuthSessions holds the value of the pending_auth_sessions edge. + PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [11]bool + loadedTypes [13]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { return nil, &NotLoadedError{edge: "payment_orders"} } +// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) { + if e.loadedTypes[10] { + return e.AuthIdentities, nil + } + return nil, &NotLoadedError{edge: "auth_identities"} +} + +// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { + if e.loadedTypes[11] { + return e.PendingAuthSessions, nil + } + return nil, &NotLoadedError{edge: "pending_auth_sessions"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[10] { + if e.loadedTypes[12] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSignupSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field signup_source", values[i]) + } else if value.Valid { + _m.SignupSource = value.String + } + case user.FieldLastLoginAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login_at", values[i]) + } else if value.Valid { + _m.LastLoginAt = new(time.Time) + *_m.LastLoginAt = value.Time + } + case user.FieldLastActiveAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_active_at", values[i]) + } else if value.Valid { + _m.LastActiveAt = new(time.Time) + *_m.LastActiveAt = value.Time + } case user.FieldBalanceNotifyEnabled: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) @@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { return NewUserClient(_m.config).QueryPaymentOrders(_m) } +// QueryAuthIdentities queries the "auth_identities" edge of the User entity. +func (_m *User) QueryAuthIdentities() *AuthIdentityQuery { + return NewUserClient(_m.config).QueryAuthIdentities(_m) +} + +// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity. +func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { + return NewUserClient(_m.config).QueryPendingAuthSessions(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) @@ -482,6 +540,19 @@ func (_m *User) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + builder.WriteString("signup_source=") + builder.WriteString(_m.SignupSource) + builder.WriteString(", ") + if v := _m.LastLoginAt; v != nil { + builder.WriteString("last_login_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.LastActiveAt; v != nil { + builder.WriteString("last_active_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index d88a3a380b165f7a8988710ff761b2895aedbec0..567e3b14f27e90b89bf094a3347a2e181fb45cf3 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,12 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSignupSource holds the string denoting the signup_source field in the database. + FieldSignupSource = "signup_source" + // FieldLastLoginAt holds the string denoting the last_login_at field in the database. + FieldLastLoginAt = "last_login_at" + // FieldLastActiveAt holds the string denoting the last_active_at field in the database. + FieldLastActiveAt = "last_active_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. @@ -73,6 +79,10 @@ const ( EdgePromoCodeUsages = "promo_code_usages" // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. EdgePaymentOrders = "payment_orders" + // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations. + EdgeAuthIdentities = "auth_identities" + // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. + EdgePendingAuthSessions = "pending_auth_sessions" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -145,6 +155,20 @@ const ( PaymentOrdersInverseTable = "payment_orders" // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. PaymentOrdersColumn = "user_id" + // AuthIdentitiesTable is the table that holds the auth_identities relation/edge. + AuthIdentitiesTable = "auth_identities" + // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + AuthIdentitiesInverseTable = "auth_identities" + // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge. + AuthIdentitiesColumn = "user_id" + // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge. + PendingAuthSessionsTable = "pending_auth_sessions" + // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionsInverseTable = "pending_auth_sessions" + // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. + PendingAuthSessionsColumn = "target_user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -171,6 +195,9 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSignupSource, + FieldLastLoginAt, + FieldLastActiveAt, FieldBalanceNotifyEnabled, FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, @@ -232,6 +259,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSignupSource holds the default value on creation for the "signup_source" field. + DefaultSignupSource string + // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + SignupSourceValidator func(string) error // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. @@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySignupSource orders the results by the signup_source field. +func BySignupSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignupSource, opts...).ToFunc() +} + +// ByLastLoginAt orders the results by the last_login_at field. +func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc() +} + +// ByLastActiveAt orders the results by the last_active_at field. +func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc() +} + // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() @@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByAuthIdentitiesCount orders the results by auth_identities count. +func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...) + } +} + +// ByAuthIdentities orders the results by auth_identities terms. +func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count. +func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...) + } +} + +// ByPendingAuthSessions orders the results by pending_auth_sessions terms. +func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), ) } +func newAuthIdentitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AuthIdentitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) +} +func newPendingAuthSessionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 2788aa7adc4f41e98c35e1d3810673632c0effcc..cbcfcc269e6dd15b0c953eb6291fdca2e5b6d5d4 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ. +func SignupSource(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ. +func LastLoginAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ. +func LastActiveAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SignupSourceEQ applies the EQ predicate on the "signup_source" field. +func SignupSourceEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field. +func SignupSourceNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSignupSource, v)) +} + +// SignupSourceIn applies the In predicate on the "signup_source" field. +func SignupSourceIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldSignupSource, vs...)) +} + +// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field. +func SignupSourceNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...)) +} + +// SignupSourceGT applies the GT predicate on the "signup_source" field. +func SignupSourceGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldSignupSource, v)) +} + +// SignupSourceGTE applies the GTE predicate on the "signup_source" field. +func SignupSourceGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldSignupSource, v)) +} + +// SignupSourceLT applies the LT predicate on the "signup_source" field. +func SignupSourceLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldSignupSource, v)) +} + +// SignupSourceLTE applies the LTE predicate on the "signup_source" field. +func SignupSourceLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldSignupSource, v)) +} + +// SignupSourceContains applies the Contains predicate on the "signup_source" field. +func SignupSourceContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldSignupSource, v)) +} + +// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field. +func SignupSourceHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v)) +} + +// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field. +func SignupSourceHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v)) +} + +// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field. +func SignupSourceEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldSignupSource, v)) +} + +// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field. +func SignupSourceContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldSignupSource, v)) +} + +// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field. +func LastLoginAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field. +func LastLoginAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtIn applies the In predicate on the "last_login_at" field. +func LastLoginAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field. +func LastLoginAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtGT applies the GT predicate on the "last_login_at" field. +func LastLoginAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastLoginAt, v)) +} + +// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field. +func LastLoginAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastLoginAt, v)) +} + +// LastLoginAtLT applies the LT predicate on the "last_login_at" field. +func LastLoginAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastLoginAt, v)) +} + +// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field. +func LastLoginAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastLoginAt, v)) +} + +// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field. +func LastLoginAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastLoginAt)) +} + +// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field. +func LastLoginAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastLoginAt)) +} + +// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field. +func LastActiveAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field. +func LastActiveAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtIn applies the In predicate on the "last_active_at" field. +func LastActiveAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field. +func LastActiveAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtGT applies the GT predicate on the "last_active_at" field. +func LastActiveAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastActiveAt, v)) +} + +// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field. +func LastActiveAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastActiveAt, v)) +} + +// LastActiveAtLT applies the LT predicate on the "last_active_at" field. +func LastActiveAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastActiveAt, v)) +} + +// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field. +func LastActiveAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastActiveAt, v)) +} + +// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field. +func LastActiveAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastActiveAt)) +} + +// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field. +func LastActiveAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastActiveAt)) +} + // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. func BalanceNotifyEnabledEQ(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { }) } +// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge. +func HasAuthIdentities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates). +func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newAuthIdentitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge. +func HasPendingAuthSessions() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates). +func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPendingAuthSessionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index fbc64f9c46d852017276c26af0071460689b6058..db95e813e3847a4cef4634a17b05ab94093e15cf 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSignupSource sets the "signup_source" field. +func (_c *UserCreate) SetSignupSource(v string) *UserCreate { + _c.mutation.SetSignupSource(v) + return _c +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate { + if v != nil { + _c.SetSignupSource(*v) + } + return _c +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate { + _c.mutation.SetLastLoginAt(v) + return _c +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastLoginAt(*v) + } + return _c +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate { + _c.mutation.SetLastActiveAt(v) + return _c +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastActiveAt(*v) + } + return _c +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { _c.mutation.SetBalanceNotifyEnabled(v) @@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { return _c.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate { + _c.mutation.AddAuthIdentityIDs(ids...) + return _c +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate { + _c.mutation.AddPendingAuthSessionIDs(ids...) + return _c +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SignupSource(); !ok { + v := user.DefaultSignupSource + _c.mutation.SetSignupSource(v) + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) @@ -589,6 +667,14 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SignupSource(); !ok { + return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)} + } + if v, ok := _c.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } @@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + _node.SignupSource = value + } + if value, ok := _c.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + _node.LastLoginAt = &value + } + if value, ok := _c.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + _node.LastActiveAt = &value + } if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value @@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsert) SetSignupSource(v string) *UserUpsert { + u.Set(user.FieldSignupSource, v) + return u +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsert) UpdateSignupSource() *UserUpsert { + u.SetExcluded(user.FieldSignupSource) + return u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastLoginAt, v) + return u +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert { + u.SetExcluded(user.FieldLastLoginAt) + return u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsert) ClearLastLoginAt() *UserUpsert { + u.SetNull(user.FieldLastLoginAt) + return u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastActiveAt, v) + return u +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert { + u.SetExcluded(user.FieldLastActiveAt) + return u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsert) ClearLastActiveAt() *UserUpsert { + u.SetNull(user.FieldLastActiveAt) + return u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { u.Set(user.FieldBalanceNotifyEnabled, v) @@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 113d87aca24a273a137326557955a6b4eb60b751..f1ee5cfe0aad9fb821eaab3d810683b5f571c6e2 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -15,8 +15,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -44,6 +46,8 @@ type UserQuery struct { withAttributeValues *UserAttributeValueQuery withPromoCodeUsages *PromoCodeUsageQuery withPaymentOrders *PaymentOrderQuery + withAuthIdentities *AuthIdentityQuery + withPendingAuthSessions *PendingAuthSessionQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { return query } +// QueryAuthIdentities chains the current query on the "auth_identities" edge. +func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge. +func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { withAttributeValues: _q.withAttributeValues.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(), + withAuthIdentities: _q.withAuthIdentities.Clone(), + withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu return _q } +// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to +// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAuthIdentities = query + return _q +} + +// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSessions = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [11]bool{ + loadedTypes = [13]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withAttributeValues != nil, _q.withPromoCodeUsages != nil, _q.withPaymentOrders != nil, + _q.withAuthIdentities != nil, + _q.withPendingAuthSessions != nil, _q.withUserAllowedGroups != nil, } ) @@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withAuthIdentities; query != nil { + if err := _q.loadAuthIdentities(ctx, query, nodes, + func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} }, + func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil { + return nil, err + } + } + if query := _q.withPendingAuthSessions; query != nil { + if err := _q.loadPendingAuthSessions(ctx, query, nodes, + func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} }, + func(n *User, e *PendingAuthSession) { + n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e) + }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ } return nil } +func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentity.FieldUserID) + } + query.Where(predicate.AuthIdentity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID) + } + query.Where(predicate.PendingAuthSession(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TargetUserID + if fk == nil { + return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 6b3552476515700564fe9d17395d548fa99628d3..677eeb6bc4fdd9db634cf08703e919804fd2b8f7 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { _u.mutation.SetBalanceNotifyEnabled(v) @@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { _u.mutation.SetBalanceNotifyEnabled(v) @@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/go.mod b/backend/go.mod index 66b6cc25b598efeb750d0d5b36153f8c95bd6863..627851bf193442134b8b5b6d71846ebb254dd944 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -39,10 +39,11 @@ require ( github.com/wechatpay-apiv3/wechatpay-go v0.2.21 github.com/zeromicro/go-zero v1.9.4 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.48.0 - golang.org/x/net v0.49.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.40.0 + golang.org/x/crypto v0.49.0 + golang.org/x/image v0.39.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.41.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -103,7 +104,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -172,10 +172,10 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.32.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.36.0 // indirect + golang.org/x/tools v0.43.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 9312af63e5ef962c2b48ac55fa7c4daac8309dc1..f1c864f54eeb6abe904cbaf8699850a5c604f815 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -162,8 +162,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -183,8 +181,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -220,8 +216,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -255,8 +249,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -286,8 +278,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -320,8 +310,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -413,16 +401,18 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww= +golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -432,16 +422,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 1559290566aaf34cc080c8d32c55e2b36a5c43f9..44bc5c9f140a7394f3232b82ea3291a4caae008a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1613,6 +1613,9 @@ func (c *Config) Validate() error { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } if c.LinuxDo.Enabled { + if !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true") + } if strings.TrimSpace(c.LinuxDo.ClientID) == "" { return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") } @@ -1634,9 +1637,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") @@ -1668,6 +1668,12 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } if c.OIDC.Enabled { + if !c.OIDC.UsePKCE { + return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true") + } + if !c.OIDC.ValidateIDToken { + return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true") + } if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") } @@ -1690,9 +1696,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.OIDC.ClientSecret) == "" { return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index cf58316c316a7297ff7b84e1e8a12f3e066fe8e3..fe48541b0614244c2e98de859872da9b4190d761 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -334,7 +334,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { cfg.LinuxDo.ClientSecret = "test-secret" cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" cfg.LinuxDo.TokenAuthMethod = "client_secret_post" - cfg.LinuxDo.UsePKCE = false + cfg.LinuxDo.UsePKCE = true cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" err = cfg.Validate() @@ -389,6 +389,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) { cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "profile email" + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err == nil { @@ -418,6 +419,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "openid email profile" cfg.OIDC.ValidateIDToken = true + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err != nil { @@ -840,6 +842,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = true if err := cfg.Validate(); err != nil { t.Fatalf("Validate() unexpected error: %v", err) @@ -990,6 +993,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo client id required", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "" }, wantErr: "linuxdo_connect.client_id", @@ -998,6 +1002,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo token auth method", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "client" c.LinuxDo.ClientSecret = "secret" c.LinuxDo.AuthorizeURL = "https://example.com/authorize" diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index cba3ae21494bcaa5cb500a8c36911787486e005c..ddeaab0218a203abb7f00216289636773ac19bf1 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) @@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + bindBody := map[string]any{ + "provider_type": "wechat", + "provider_key": "wechat-main", + "provider_subject": "union-123", + "metadata": map[string]any{"source": "admin-repair"}, + "channel": map[string]any{ + "channel": "open", + "channel_app_id": "wx-open", + "channel_subject": "openid-123", + }, + } + body, _ := json.Marshal(bindBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} - body, _ := json.Marshal(createBody) + body, _ = json.Marshal(createBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) } +func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) { + router, adminSvc := setupAdminRouter() + + body, err := json.Marshal(map[string]any{ + "provider_type": "oidc", + "provider_key": "https://issuer.example", + "provider_subject": "subject-123", + "issuer": "https://issuer.example", + "metadata": map[string]any{"report_id": 12}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor) + require.NotNil(t, adminSvc.boundAuthIdentity) + require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType) + require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey) + require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject) + require.Nil(t, adminSvc.boundAuthIdentity.Channel) + require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"]) +} + func TestGroupHandlerEndpoints(t *testing.T) { router, _ := setupAdminRouter() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 6d1ef1b6b82426d8d2679808eb16073410cec2e2..3a395342e66965c15ae01afbe2de0dbe8f288d9c 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -17,6 +17,8 @@ type stubAdminService struct { proxies []service.Proxy proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode + boundAuthIdentity *service.AdminBindAuthIdentityInput + boundAuthIdentityFor int64 createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -42,6 +44,14 @@ type stubAdminService struct { sortOrder string calls int } + lastListUsers struct { + page int + pageSize int + filters service.UserListFilters + sortBy string + sortOrder string + calls int + } lastListProxies struct { protocol string status string @@ -127,6 +137,12 @@ func newStubAdminService() *stubAdminService { } func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) { + s.lastListUsers.page = page + s.lastListUsers.pageSize = pageSize + s.lastListUsers.filters = filters + s.lastListUsers.sortBy = sortBy + s.lastListUsers.sortOrder = sortOrder + s.lastListUsers.calls++ return s.users, int64(len(s.users)), nil } @@ -167,6 +183,52 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { + s.boundAuthIdentityFor = userID + copied := input + if input.Metadata != nil { + copied.Metadata = map[string]any{} + for key, value := range input.Metadata { + copied.Metadata[key] = value + } + } + if input.Channel != nil { + channel := *input.Channel + if input.Channel.Metadata != nil { + channel.Metadata = map[string]any{} + for key, value := range input.Channel.Metadata { + channel.Metadata[key] = value + } + } + copied.Channel = &channel + } + s.boundAuthIdentity = &copied + + now := time.Now().UTC() + result := &service.AdminBoundAuthIdentity{ + UserID: userID, + ProviderType: input.ProviderType, + ProviderKey: input.ProviderKey, + ProviderSubject: input.ProviderSubject, + VerifiedAt: &now, + Issuer: input.Issuer, + Metadata: input.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + if input.Channel != nil { + result.Channel = &service.AdminBoundAuthIdentityChannel{ + Channel: input.Channel.Channel, + ChannelAppID: input.Channel.ChannelAppID, + ChannelSubject: input.Channel.ChannelSubject, + Metadata: input.Channel.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + } + return result, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go index b0ed6aed8f7fca88c0ed58d7b9744958e70a6766..84359cd93bdecfae7d562a57a29655109a2ddd2a 100644 --- a/backend/internal/handler/admin/payment_handler.go +++ b/backend/internal/handler/admin/payment_handler.go @@ -3,6 +3,7 @@ package admin import ( "strconv" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Paginated(c, orders, int64(total), page, pageSize) + response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize) } // GetOrderDetail returns detailed information about a single order. @@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) { return } auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID) - response.Success(c, gin.H{"order": order, "auditLogs": auditLogs}) + response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs}) } // CancelOrder cancels a pending order (admin). @@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) { response.Success(c, gin.H{"message": "fulfillment retried"}) } +func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder { + if len(orders) == 0 { + return orders + } + out := make([]*dbent.PaymentOrder, 0, len(orders)) + for _, order := range orders { + out = append(out, sanitizeAdminPaymentOrderForResponse(order)) + } + return out +} + +func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder { + if order == nil { + return nil + } + cloned := *order + cloned.ProviderSnapshot = nil + return &cloned +} + // AdminProcessRefundRequest is the request body for admin refund processing. type AdminProcessRefundRequest struct { Amount float64 `json:"amount"` diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index bec0f126137a2b05ca7acde898e4c1b6d21943ad..e6609c97123f2fcefc3773a4292f1fdd6dff248f 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -43,6 +43,15 @@ func scopesContainOpenID(scopes string) bool { return false } +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -73,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) @@ -93,114 +107,136 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { paymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - FrontendURL: settings.FrontendURL, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPasswordConfigured: settings.SMTPPasswordConfigured, - SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, - LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, - LinuxDoConnectClientID: settings.LinuxDoConnectClientID, - LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, - LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, - OIDCConnectEnabled: settings.OIDCConnectEnabled, - OIDCConnectProviderName: settings.OIDCConnectProviderName, - OIDCConnectClientID: settings.OIDCConnectClientID, - OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured, - OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL, - OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL, - OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL, - OIDCConnectTokenURL: settings.OIDCConnectTokenURL, - OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL, - OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL, - OIDCConnectScopes: settings.OIDCConnectScopes, - OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL, - OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL, - OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken, - OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs, - OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds, - OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified, - OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, - OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, - OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - TableDefaultPageSize: settings.TableDefaultPageSize, - TablePageSizeOptions: settings.TablePageSizeOptions, - CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), - CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, - OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, - OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, - OpsQueryModeDefault: settings.OpsQueryModeDefault, - OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: settings.MinClaudeCodeVersion, - MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion, - AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, - BackendModeEnabled: settings.BackendModeEnabled, - EnableFingerprintUnification: settings.EnableFingerprintUnification, - EnableMetadataPassthrough: settings.EnableMetadataPassthrough, - EnableCCHSigning: settings.EnableCCHSigning, - WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, - BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, - BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, - AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, - AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), - PaymentEnabled: paymentCfg.Enabled, - PaymentMinAmount: paymentCfg.MinAmount, - PaymentMaxAmount: paymentCfg.MaxAmount, - PaymentDailyLimit: paymentCfg.DailyLimit, - PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin, - PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, - PaymentEnabledTypes: paymentCfg.EnabledTypes, - PaymentBalanceDisabled: paymentCfg.BalanceDisabled, - PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier, - PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate, - PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, - PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, - PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, - PaymentHelpImageURL: paymentCfg.HelpImageURL, - PaymentHelpText: paymentCfg.HelpText, - PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled, - PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax, - PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, - PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, - PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - }) + payload := dto.SystemSettings{ + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + FrontendURL: settings.FrontendURL, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: settings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: settings.WeChatConnectEnabled, + WeChatConnectAppID: settings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured, + WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured, + WeChatConnectMPAppID: settings.WeChatConnectMPAppID, + WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured, + WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured, + WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled, + WeChatConnectMode: settings.WeChatConnectMode, + WeChatConnectScopes: settings.WeChatConnectScopes, + WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL, + OIDCConnectEnabled: settings.OIDCConnectEnabled, + OIDCConnectProviderName: settings.OIDCConnectProviderName, + OIDCConnectClientID: settings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: settings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL, + OIDCConnectScopes: settings.OIDCConnectScopes, + OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + TableDefaultPageSize: settings.TableDefaultPageSize, + TablePageSizeOptions: settings.TablePageSizeOptions, + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, + FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, + OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, + OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, + OpsQueryModeDefault: settings.OpsQueryModeDefault, + OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: settings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion, + AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, + BackendModeEnabled: settings.BackendModeEnabled, + EnableFingerprintUnification: settings.EnableFingerprintUnification, + EnableMetadataPassthrough: settings.EnableMetadataPassthrough, + EnableCCHSigning: settings.EnableCCHSigning, + WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), + PaymentEnabled: paymentCfg.Enabled, + PaymentMinAmount: paymentCfg.MinAmount, + PaymentMaxAmount: paymentCfg.MaxAmount, + PaymentDailyLimit: paymentCfg.DailyLimit, + PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin, + PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, + PaymentEnabledTypes: paymentCfg.EnabledTypes, + PaymentBalanceDisabled: paymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate, + PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, + PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, + PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, + PaymentHelpImageURL: paymentCfg.HelpImageURL, + PaymentHelpText: paymentCfg.HelpText, + PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled, + PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax, + PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, + PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, + PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } // UpdateSettingsRequest 更新设置请求 @@ -235,6 +271,24 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecret string `json:"wechat_connect_app_secret"` + WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` + WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"` + WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` + WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"` + WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` + WeChatConnectMobileAppSecret string `json:"wechat_connect_mobile_app_secret"` + WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` + WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` + WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` @@ -276,9 +330,30 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -311,6 +386,15 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"` + // Balance low notification BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` @@ -357,6 +441,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // 验证参数 if req.DefaultConcurrency < 1 { @@ -381,6 +470,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions) + req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) + req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) + req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -459,6 +552,124 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + if req.WeChatConnectEnabled { + req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID) + req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret) + req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID) + req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret) + req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID) + req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret) + req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID) + req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret) + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode)) + req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes) + req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL) + + if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled { + response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time") + return + } + if req.WeChatConnectMode != "" { + switch req.WeChatConnectMode { + case "open", "mp", "mobile": + default: + response.BadRequest(c, "WeChat mode must be open, mp, or mobile") + return + } + } + if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled { + switch req.WeChatConnectMode { + case "mp": + req.WeChatConnectMPEnabled = true + case "mobile": + req.WeChatConnectMobileEnabled = true + default: + req.WeChatConnectOpenEnabled = true + } + } + if req.WeChatConnectMode == "" { + if req.WeChatConnectMPEnabled { + req.WeChatConnectMode = "mp" + } else if req.WeChatConnectMobileEnabled { + req.WeChatConnectMode = "mobile" + } else { + req.WeChatConnectMode = "open" + } + } + + req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID)) + req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID)) + req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID)) + + if req.WeChatConnectOpenAppSecret == "" { + req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectMPAppSecret == "" { + req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectMobileAppSecret == "" { + req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectAppSecret == "" { + req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret)) + } + + if req.WeChatConnectOpenEnabled { + if req.WeChatConnectOpenAppID == "" { + response.BadRequest(c, "WeChat PC App ID is required when enabled") + return + } + if req.WeChatConnectOpenAppSecret == "" { + response.BadRequest(c, "WeChat PC App Secret is required when enabled") + return + } + } + if req.WeChatConnectMPEnabled { + if req.WeChatConnectMPAppID == "" { + response.BadRequest(c, "WeChat Official Account App ID is required when enabled") + return + } + if req.WeChatConnectMPAppSecret == "" { + response.BadRequest(c, "WeChat Official Account App Secret is required when enabled") + return + } + } + if req.WeChatConnectMobileEnabled { + if req.WeChatConnectMobileAppID == "" { + response.BadRequest(c, "WeChat Mobile App ID is required when enabled") + return + } + if req.WeChatConnectMobileAppSecret == "" { + response.BadRequest(c, "WeChat Mobile App Secret is required when enabled") + return + } + } + + if req.WeChatConnectScopes == "" { + if req.WeChatConnectMPEnabled { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp") + } else { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode) + } + } + if req.WeChatConnectRedirectURL == "" { + response.BadRequest(c, "WeChat Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { + response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") + return + } + if req.WeChatConnectFrontendRedirectURL == "" { + req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" + } + if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") + return + } + } + // Generic OIDC 参数验证 if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) @@ -538,25 +749,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC scopes must contain openid") return } + if !req.OIDCConnectUsePKCE { + response.BadRequest(c, "OIDC PKCE must be enabled") + return + } + if !req.OIDCConnectValidateIDToken { + response.BadRequest(c, "OIDC ID Token validation must be enabled") + return + } switch req.OIDCConnectTokenAuthMethod { case "", "client_secret_post", "client_secret_basic", "none": default: response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") return } - if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") - return - } if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken { - if req.OIDCConnectAllowedSigningAlgs == "" { - response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") - return - } + if req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return } if req.OIDCConnectJWKSURL != "" { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { @@ -805,6 +1018,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: req.WeChatConnectEnabled, + WeChatConnectAppID: req.WeChatConnectAppID, + WeChatConnectAppSecret: req.WeChatConnectAppSecret, + WeChatConnectOpenAppID: req.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret, + WeChatConnectMPAppID: req.WeChatConnectMPAppID, + WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret, + WeChatConnectMobileAppID: req.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret, + WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: req.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled, + WeChatConnectMode: req.WeChatConnectMode, + WeChatConnectScopes: req.WeChatConnectScopes, + WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: req.OIDCConnectEnabled, OIDCConnectProviderName: req.OIDCConnectProviderName, OIDCConnectClientID: req.OIDCConnectClientID, @@ -897,6 +1126,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + PaymentVisibleMethodAlipaySource: func() string { + if req.PaymentVisibleMethodAlipaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) + } + return previousSettings.PaymentVisibleMethodAlipaySource + }(), + PaymentVisibleMethodWxpaySource: func() string { + if req.PaymentVisibleMethodWxpaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource) + } + return previousSettings.PaymentVisibleMethodWxpaySource + }(), + PaymentVisibleMethodAlipayEnabled: func() bool { + if req.PaymentVisibleMethodAlipayEnabled != nil { + return *req.PaymentVisibleMethodAlipayEnabled + } + return previousSettings.PaymentVisibleMethodAlipayEnabled + }(), + PaymentVisibleMethodWxpayEnabled: func() bool { + if req.PaymentVisibleMethodWxpayEnabled != nil { + return *req.PaymentVisibleMethodWxpayEnabled + } + return previousSettings.PaymentVisibleMethodWxpayEnabled + }(), + OpenAIAdvancedSchedulerEnabled: func() bool { + if req.OpenAIAdvancedSchedulerEnabled != nil { + return *req.OpenAIAdvancedSchedulerEnabled + } + return previousSettings.OpenAIAdvancedSchedulerEnabled + }(), BalanceLowNotifyEnabled: func() bool { if req.BalanceLowNotifyEnabled != nil { return *req.BalanceLowNotifyEnabled @@ -929,7 +1188,38 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { }(), } - if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { + authSourceDefaults := &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + }, + LinuxDo: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + }, + OIDC: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + }, + WeChat: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + }, + ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), + } + if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil { response.ErrorFrom(c, err) return } @@ -969,7 +1259,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } - h.auditSettingsUpdate(c, previousSettings, settings, req) + h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req) // 重新获取设置返回 updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) @@ -977,6 +1267,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) for _, sub := range updatedSettings.DefaultSubscriptions { updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ @@ -994,113 +1289,135 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { updatedPaymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: updatedSettings.PromoCodeEnabled, - PasswordResetEnabled: updatedSettings.PasswordResetEnabled, - FrontendURL: updatedSettings.FrontendURL, - InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, - TotpEnabled: updatedSettings.TotpEnabled, - TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, - LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, - LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, - LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, - LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, - OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, - OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, - OIDCConnectClientID: updatedSettings.OIDCConnectClientID, - OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured, - OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL, - OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL, - OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL, - OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL, - OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL, - OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL, - OIDCConnectScopes: updatedSettings.OIDCConnectScopes, - OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL, - OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL, - OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken, - OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs, - OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds, - OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified, - OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, - OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, - OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - HomeContent: updatedSettings.HomeContent, - HideCcsImportButton: updatedSettings.HideCcsImportButton, - PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, - TableDefaultPageSize: updatedSettings.TableDefaultPageSize, - TablePageSizeOptions: updatedSettings.TablePageSizeOptions, - CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), - CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - DefaultSubscriptions: updatedDefaultSubscriptions, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, - OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, - OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, - OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, - OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, - MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion, - AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, - BackendModeEnabled: updatedSettings.BackendModeEnabled, - EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, - EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, - EnableCCHSigning: updatedSettings.EnableCCHSigning, - BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, - BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, - AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, - AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), - PaymentEnabled: updatedPaymentCfg.Enabled, - PaymentMinAmount: updatedPaymentCfg.MinAmount, - PaymentMaxAmount: updatedPaymentCfg.MaxAmount, - PaymentDailyLimit: updatedPaymentCfg.DailyLimit, - PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin, - PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, - PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, - PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, - PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier, - PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate, - PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, - PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, - PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, - PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL, - PaymentHelpText: updatedPaymentCfg.HelpText, - PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled, - PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax, - PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, - PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, - PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - }) + payload := dto.SystemSettings{ + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: updatedSettings.PromoCodeEnabled, + PasswordResetEnabled: updatedSettings.PasswordResetEnabled, + FrontendURL: updatedSettings.FrontendURL, + InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, + TotpEnabled: updatedSettings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled, + WeChatConnectAppID: updatedSettings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured, + WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured, + WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID, + WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured, + WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured, + WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled, + WeChatConnectMode: updatedSettings.WeChatConnectMode, + WeChatConnectScopes: updatedSettings.WeChatConnectScopes, + WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL, + OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, + OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, + OIDCConnectClientID: updatedSettings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL, + OIDCConnectScopes: updatedSettings.OIDCConnectScopes, + OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + HomeContent: updatedSettings.HomeContent, + HideCcsImportButton: updatedSettings.HideCcsImportButton, + PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + TableDefaultPageSize: updatedSettings.TableDefaultPageSize, + TablePageSizeOptions: updatedSettings.TablePageSizeOptions, + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + DefaultSubscriptions: updatedDefaultSubscriptions, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, + OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, + OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, + OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion, + AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, + BackendModeEnabled: updatedSettings.BackendModeEnabled, + EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, + EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, + EnableCCHSigning: updatedSettings.EnableCCHSigning, + PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled, + BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), + PaymentEnabled: updatedPaymentCfg.Enabled, + PaymentMinAmount: updatedPaymentCfg.MinAmount, + PaymentMaxAmount: updatedPaymentCfg.MaxAmount, + PaymentDailyLimit: updatedPaymentCfg.DailyLimit, + PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin, + PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, + PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, + PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate, + PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, + PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, + PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, + PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL, + PaymentHelpText: updatedPaymentCfg.HelpText, + PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled, + PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax, + PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, + PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, + PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, + } + response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } // hasPaymentFields returns true if any payment-related field was explicitly provided. @@ -1117,12 +1434,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool { req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil } -func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) { +func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) { if before == nil || after == nil { return } - changed := diffSettings(before, after, req) + changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req) if len(changed) == 0 { return } @@ -1137,7 +1454,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys ) } -func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string { +func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string { changed := make([]string, 0, 20) if before.RegistrationEnabled != after.RegistrationEnabled { changed = append(changed, "registration_enabled") @@ -1205,6 +1522,54 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.WeChatConnectEnabled != after.WeChatConnectEnabled { + changed = append(changed, "wechat_connect_enabled") + } + if before.WeChatConnectAppID != after.WeChatConnectAppID { + changed = append(changed, "wechat_connect_app_id") + } + if req.WeChatConnectAppSecret != "" { + changed = append(changed, "wechat_connect_app_secret") + } + if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID { + changed = append(changed, "wechat_connect_open_app_id") + } + if req.WeChatConnectOpenAppSecret != "" { + changed = append(changed, "wechat_connect_open_app_secret") + } + if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID { + changed = append(changed, "wechat_connect_mp_app_id") + } + if req.WeChatConnectMPAppSecret != "" { + changed = append(changed, "wechat_connect_mp_app_secret") + } + if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID { + changed = append(changed, "wechat_connect_mobile_app_id") + } + if req.WeChatConnectMobileAppSecret != "" { + changed = append(changed, "wechat_connect_mobile_app_secret") + } + if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled { + changed = append(changed, "wechat_connect_open_enabled") + } + if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled { + changed = append(changed, "wechat_connect_mp_enabled") + } + if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled { + changed = append(changed, "wechat_connect_mobile_enabled") + } + if before.WeChatConnectMode != after.WeChatConnectMode { + changed = append(changed, "wechat_connect_mode") + } + if before.WeChatConnectScopes != after.WeChatConnectScopes { + changed = append(changed, "wechat_connect_scopes") + } + if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL { + changed = append(changed, "wechat_connect_redirect_url") + } + if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL { + changed = append(changed, "wechat_connect_frontend_redirect_url") + } if before.OIDCConnectEnabled != after.OIDCConnectEnabled { changed = append(changed, "oidc_connect_enabled") } @@ -1376,6 +1741,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { + changed = append(changed, "payment_visible_method_alipay_source") + } + if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource { + changed = append(changed, "payment_visible_method_wxpay_source") + } + if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled { + changed = append(changed, "payment_visible_method_alipay_enabled") + } + if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled { + changed = append(changed, "payment_visible_method_wxpay_enabled") + } + if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled { + changed = append(changed, "openai_advanced_scheduler_enabled") + } // Balance & quota notification if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { changed = append(changed, "balance_low_notify_enabled") @@ -1392,6 +1772,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) { changed = append(changed, "account_quota_notify_emails") } + changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) + return changed +} + +func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string { + if before == nil { + before = &service.AuthSourceDefaultSettings{} + } + if after == nil { + after = &service.AuthSourceDefaultSettings{} + } + + type providerDefaultGrantField struct { + name string + before service.ProviderDefaultGrantSettings + after service.ProviderDefaultGrantSettings + } + + fields := []providerDefaultGrantField{ + {name: "email", before: before.Email, after: after.Email}, + {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo}, + {name: "oidc", before: before.OIDC, after: after.OIDC}, + {name: "wechat", before: before.WeChat, after: after.WeChat}, + } + for _, field := range fields { + if field.before.Balance != field.after.Balance { + changed = append(changed, "auth_source_default_"+field.name+"_balance") + } + if field.before.Concurrency != field.after.Concurrency { + changed = append(changed, "auth_source_default_"+field.name+"_concurrency") + } + if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) { + changed = append(changed, "auth_source_default_"+field.name+"_subscriptions") + } + if field.before.GrantOnSignup != field.after.GrantOnSignup { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup") + } + if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind") + } + } + if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup { + changed = append(changed, "force_email_on_third_party_signup") + } return changed } @@ -1412,6 +1836,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting { + if input == nil { + return nil + } + normalized := normalizeDefaultSubscriptions(*input) + return &normalized +} + +func float64ValueOrDefault(value *float64, fallback float64) float64 { + if value == nil { + return fallback + } + return *value +} + +func intValueOrDefault(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func boolValueOrDefault(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting { + if input == nil { + return fallback + } + result := make([]service.DefaultSubscriptionSetting, 0, len(*input)) + for _, item := range *input { + result = append(result, service.DefaultSubscriptionSetting{ + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + }) + } + return result +} + +func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { + data := make(map[string]any) + raw, err := json.Marshal(settings) + if err == nil { + _ = json.Unmarshal(raw, &data) + } + if authSourceDefaults == nil { + authSourceDefaults = &service.AuthSourceDefaultSettings{} + } + + data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance + data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency + data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions + data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup + data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind + data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance + data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency + data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions + data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup + data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance + data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency + data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions + data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup + data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind + data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance + data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency + data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions + data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup + data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup + + return data +} + func equalStringSlice(a, b []string) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cef531e0bab7f72d50c73e34dcf0fd947d4991bd --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -0,0 +1,346 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerRepoStub struct { + values map[string]string + lastUpdates map[string]string +} + +func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.lastUpdates = make(map[string]string, len(settings)) + for key, value := range settings { + s.lastUpdates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type failingAuthSourceSettingsRepoStub struct { + values map[string]string + err error +} + +func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok { + return s.err + } + for key, value := range settings { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + + handler.GetSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 9.5, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) + + subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any) + require.True(t, ok) + require.Len(t, subscriptions, 1) +} + +func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions]) + require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 12.75, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) +} + +func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "easypay", + "payment_visible_method_wxpay_source": "wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"]) + require.Equal(t, true, data["payment_visible_method_alipay_enabled"]) + require.Equal(t, false, data["payment_visible_method_wxpay_enabled"]) + require.Equal(t, true, data["openai_advanced_scheduler_enabled"]) +} + +func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "bogus", + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource) +} + +func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &failingAuthSourceSettingsRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + }, + err: errors.New("write auth source defaults failed"), + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled]) + require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) +} + +func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) { + changed := diffSettings( + &service.SystemSettings{}, + &service.SystemSettings{}, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 0, + Concurrency: 5, + Subscriptions: nil, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: false, + }, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 12.5, + Concurrency: 7, + Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + ForceEmailOnThirdPartySignup: true, + }, + UpdateSettingsRequest{}, + ) + + require.Contains(t, changed, "auth_source_default_email_balance") + require.Contains(t, changed, "auth_source_default_email_concurrency") + require.Contains(t, changed, "auth_source_default_email_subscriptions") + require.Contains(t, changed, "auth_source_default_email_grant_on_signup") + require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind") + require.Contains(t, changed, "force_email_on_third_party_signup") +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1453bd0739ddf6a518e010db4f3256fcdc8692fe..b2ed9d18b6e62e8c0a2ebc2ba1ab75d112ac69b2 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -66,6 +66,22 @@ type UpdateBalanceRequest struct { Notes string `json:"notes"` } +type BindUserAuthIdentityRequest struct { + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + Issuer *string `json:"issuer"` + Metadata map[string]any `json:"metadata"` + Channel *BindUserAuthIdentityChannelRequest `json:"channel"` +} + +type BindUserAuthIdentityChannelRequest struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -172,6 +188,45 @@ func (h *UserHandler) GetByID(c *gin.Context) { response.Success(c, dto.UserFromServiceAdmin(user)) } +// BindAuthIdentity manually binds a canonical auth identity to a user. +// POST /api/v1/admin/users/:id/auth-identities +func (h *UserHandler) BindAuthIdentity(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req BindUserAuthIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := service.AdminBindAuthIdentityInput{ + ProviderType: req.ProviderType, + ProviderKey: req.ProviderKey, + ProviderSubject: req.ProviderSubject, + Issuer: req.Issuer, + Metadata: req.Metadata, + } + if req.Channel != nil { + input.Channel = &service.AdminBindAuthIdentityChannelInput{ + Channel: req.Channel.Channel, + ChannelAppID: req.Channel.ChannelAppID, + ChannelSubject: req.Channel.ChannelSubject, + Metadata: req.Channel.Metadata, + } + } + + result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bfba2408035a9766e8b2ba1235068b7127280052 --- /dev/null +++ b/backend/internal/handler/admin/user_handler_activity_test.go @@ -0,0 +1,114 @@ +//go:build unit + +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) { + gin.SetMode(gin.TestMode) + + lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(30 * time.Minute) + lastUsedAt := lastLoginAt.Add(90 * time.Minute) + + adminSvc := newStubAdminService() + adminSvc.users = []service.User{ + { + ID: 7, + Email: "activity@example.com", + Username: "activity-user", + Role: service.RoleUser, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + CreatedAt: lastLoginAt.Add(-24 * time.Hour), + UpdatedAt: lastLoginAt, + }, + } + handler := NewUserHandler(adminSvc, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity", + nil, + ) + + handler.List(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy) + require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder) + require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search) + + var resp struct { + Code int `json:"code"` + Data struct { + Items []struct { + LastActiveAt *time.Time `json:"last_active_at"` + LastUsedAt *time.Time `json:"last_used_at"` + } `json:"items"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Items, 1) + require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second) +} + +func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(30 * time.Minute) + lastUsedAt := lastLoginAt.Add(90 * time.Minute) + + adminSvc := newStubAdminService() + adminSvc.users = []service.User{ + { + ID: 8, + Email: "detail@example.com", + Username: "detail-user", + Role: service.RoleUser, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + CreatedAt: lastLoginAt.Add(-24 * time.Hour), + UpdatedAt: lastLoginAt, + }, + } + handler := NewUserHandler(adminSvc, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Params = gin.Params{{Key: "id", Value: "8"}} + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil) + + handler.GetByID(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + LastActiveAt *time.Time `json:"last_active_at"` + LastUsedAt *time.Time `json:"last_used_at"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second) +} diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go new file mode 100644 index 0000000000000000000000000000000000000000..31d92a3613812cd7ab8e770c471dc41a34341e5b --- /dev/null +++ b/backend/internal/handler/auth_current_user_test.go @@ -0,0 +1,85 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 31, + Email: "me@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-31", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + + handler := &AuthHandler{ + userService: service.NewUserService(repo, nil, nil, nil), + } + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31}) + + handler.GetCurrentUser(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + require.Equal(t, "linuxdo", avatarSource["source"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) + require.Equal(t, "linuxdo", usernameSource["source"]) +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index f4ddf890caa50e36acb89bdeb39c0a0ef24d4cc8..9801b3b395090294533c5cc42a4c8bccd2bfe913 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,11 +1,13 @@ package handler import ( + "context" "log/slog" "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -104,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { }) } +func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.") +} + +func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error { + if h == nil || !h.isBackendModeEnabled(ctx) { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.") +} + +func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + settings, err := h.settingSvc.GetPublicSettings(ctx) + if err == nil && settings != nil { + return settings.BackendModeEnabled + } + return h.settingSvc.IsBackendModeEnabled(ctx) +} + // Register handles user registration // POST /api/v1/auth/register func (h *AuthHandler) Register(c *gin.Context) { @@ -177,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) { } _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } + // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { // Create a temporary login session for 2FA @@ -194,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - // Backend mode: only admin can login - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") - return - } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) h.respondWithTokenPair(c, user) } @@ -263,15 +294,75 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Backend mode: only admin can login (check BEFORE deleting session) - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) return } + if session.PendingOAuthBind != nil { + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + + pendingSession, err := pendingSvc.GetBrowserSession( + c.Request.Context(), + session.PendingOAuthBind.PendingSessionToken, + session.PendingOAuthBind.BrowserSessionKey, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{}) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding( + c.Request.Context(), + h.entClient(), + h.authService, + h.userService, + pendingSession, + decision, + &user.ID, + true, + true, + ); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession( + c.Request.Context(), + pendingSession.SessionToken, + pendingSession.BrowserSessionKey, + ); err != nil { + response.ErrorFrom(c, err) + return + } + + secureCookie := isRequestHTTPS(c) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + + user, err = h.userService.GetByID(c.Request.Context(), session.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + // Delete the login session (only after all checks pass) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + if session.PendingOAuthBind == nil { + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + } + h.respondWithTokenPair(c, user) } @@ -290,8 +381,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { return } + identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user) + if err != nil { + response.ErrorFrom(c, err) + return + } + type UserResponse struct { - *dto.User + userProfileResponse RunMode string `json:"run_mode"` } @@ -300,7 +397,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { runMode = h.cfg.RunMode } - response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) + response.Success(c, UserResponse{ + userProfileResponse: userProfileResponseFromService(user, identities), + RunMode: runMode, + }) } // ValidatePromoCodeRequest 验证优惠码请求 diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0c7c2da7ab49a5dcfb557dfcee1a75300687a8de..e0bee2f5b2bdddf77a5a7f0042fd85bf383902cd 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -13,10 +15,13 @@ import ( "time" "unicode/utf8" + dbent "github.com/Wei-Shaw/sub2api/ent" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -25,17 +30,24 @@ import ( ) const ( - linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" - linuxDoOAuthStateCookieName = "linuxdo_oauth_state" - linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" - linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" - linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - linuxDoOAuthDefaultRedirectTo = "/dashboard" - linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent" + linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user" + oauthBindAccessTokenCookieName = "oauth_bind_access_token" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" linuxDoOAuthMaxRedirectLen = 2048 linuxDoOAuthMaxFragmentValueLen = 512 linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") + + oauthIntentLogin = "login" + oauthIntentBindCurrentUser = "bind_current_user" ) type linuxDoTokenResponse struct { @@ -87,20 +99,37 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) - - codeChallenge := "" - if cfg.UsePKCE { - verifier, err := oauth.GenerateCodeVerifier() + intent := normalizeOAuthIntent(c.Query("intent")) + setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + response.ErrorFrom(c, err) return } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } else { + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + } + + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return } + codeChallenge := oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -148,6 +177,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) @@ -161,14 +192,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) - codeVerifier := "" - if cfg.UsePKCE { - codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return - } + codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -198,52 +233,202 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") return } + compatEmail := strings.TrimSpace(email) // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 if subject != "" { email = linuxDoSyntheticEmail(subject) } + identityKey := service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: identityKey, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } - // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey) if err != nil { - if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") - return - } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + redirectToFrontendCallback(c, frontendCallback) return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createLinuxDoOAuthChoicePendingSession( + c, + identityKey, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + h.isForceEmailOnThirdPartySignup(c.Request.Context()), + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEqualFold(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + if forceEmailOnSignup && compatEmailUser == nil { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) } type completeLinuxDoOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -256,17 +441,75 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -303,9 +546,7 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { - form.Set("code_verifier", codeVerifier) - } + form.Set("code_verifier", codeVerifier) r := client.R(). SetContext(ctx). @@ -353,11 +594,11 @@ func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, -) (email string, username string, subject string, err error) { +) (email string, username string, subject string, displayName string, avatarURL string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { - return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). @@ -366,16 +607,16 @@ func linuxDoFetchUserInfo( SetHeader("Authorization", authorization). Get(cfg.UserInfoURL) if err != nil { - return "", "", "", fmt.Errorf("request userinfo: %w", err) + return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { - return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } -func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), @@ -400,12 +641,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s getGJSON(body, "user.id"), ) + displayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "user.name"), + getGJSON(body, "user.username"), + username, + ) + avatarURL = firstNonEmpty( + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "picture"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) + subject = strings.TrimSpace(subject) if subject == "" { - return "", "", "", errors.New("userinfo missing id field") + return "", "", "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { - return "", "", "", errors.New("userinfo returned invalid id field") + return "", "", "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) @@ -418,8 +676,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s if username == "" { username = "linuxdo_" + subject } + displayName = strings.TrimSpace(displayName) + if displayName == "" { + displayName = username + } + avatarURL = strings.TrimSpace(avatarURL) - return email, username, subject, nil + return email, username, subject, displayName, avatarURL, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { @@ -436,10 +699,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - if cfg.UsePKCE { - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - } + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") u.RawQuery = q.Encode() return u.String(), nil @@ -670,6 +931,18 @@ func clearCookie(c *gin.Context, name string, secure bool) { }) } +func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthBindAccessTokenCookieName, + Value: "", + Path: oauthBindAccessTokenCookiePath, + MaxAge: -1, + HttpOnly: false, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + func truncateFragmentValue(value string) string { value = strings.TrimSpace(value) if value == "" { @@ -728,3 +1001,107 @@ func linuxDoSyntheticEmail(subject string) string { } return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain } + +func normalizeOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", oauthIntentLogin: + return oauthIntentLogin + case "bind", oauthIntentBindCurrentUser: + return oauthIntentBindCurrentUser + default: + return oauthIntentLogin + } +} + +func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) { + userID, err := h.resolveOAuthBindTargetUserID(c) + if err != nil || userID == nil || *userID <= 0 { + return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required") + } + return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) { + if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return &subject.UserID, nil + } + if h == nil || h.authService == nil || h.userService == nil { + return nil, service.ErrInvalidToken + } + + ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName) + clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c)) + if err != nil { + return nil, err + } + + tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value)) + if err != nil { + return nil, err + } + if tokenString == "" { + return nil, service.ErrInvalidToken + } + + claims, err := h.authService.ValidateToken(tokenString) + if err != nil { + return nil, err + } + user, err := h.userService.GetByID(c.Request.Context(), claims.UserID) + if err != nil { + return nil, err + } + if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion { + return nil, service.ErrInvalidToken + } + return &user.ID, nil +} + +func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) { + value, err := readCookieDecoded(c, cookieName) + if err != nil { + return 0, err + } + return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) oauthBindCookieSecret() string { + if h == nil || h.cfg == nil { + return "" + } + return strings.TrimSpace(h.cfg.JWT.Secret) +} + +func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) { + secret = strings.TrimSpace(secret) + if userID <= 0 || secret == "" { + return "", errors.New("invalid oauth bind cookie input") + } + payload := strconv.FormatInt(userID, 10) + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return payload + "." + signature, nil +} + +func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) { + secret = strings.TrimSpace(secret) + if secret == "" { + return 0, errors.New("missing oauth bind cookie secret") + } + payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".") + if !ok || payload == "" || signature == "" { + return 0, errors.New("invalid oauth bind cookie") + } + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { + return 0, errors.New("invalid oauth bind cookie signature") + } + userID, err := strconv.ParseInt(payload, 10, 64) + if err != nil || userID <= 0 { + return 0, errors.New("invalid oauth bind cookie user") + } + return userID, nil +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index ff169c52ad694b76b53ea02892bbd05496aadd4a..0c760ee91a3c8f115deedc59efb6161016777324 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -1,10 +1,23 @@ package handler import ( + "bytes" + "context" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -41,11 +54,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "alice", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "Alice", displayName) + require.Equal(t, "https://cdn.example/avatar.png", avatarURL) } func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { @@ -53,11 +68,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "linuxdo_123", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "linuxdo_123", displayName) + require.Equal(t, "", avatarURL) } func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { @@ -65,11 +82,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) require.Error(t, err) tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) - _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) require.Error(t, err) } @@ -106,3 +123,559 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "", singleLine("\n\t\r")) } + +func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42}) + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "connect.linux.do/oauth/authorize") + require.Contains(t, location, "client_id=linuxdo-client") + require.Contains(t, location, "code_challenge=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(42), userID) +} + +func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("bind-cookie@example.com"). + SetUsername("bind-cookie-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + token, err := handler.authService.GenerateToken(&service.User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + PasswordHash: user.PasswordHash, + Role: user.Role, + Status: user.Status, + }) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath}) + c.Request = req + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + + bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, user.ID, userID) + + accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName) + require.NotNil(t, accessTokenCookie) + require.Equal(t, -1, accessTokenCookie.MaxAge) +} + +func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(linuxDoSyntheticEmail("321")). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("321"). + SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail) + require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.NotEmpty(t, completion["access_token"]) + require.Nil(t, completion["error"]) +} + +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + +func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + +func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-1"). + SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Display", + "suggested_avatar_url": "https://cdn.example/linuxdo.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "LinuxDo Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("linuxdo-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + LinuxDo: oauthCfg, + } + return handler, client +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..7d7b50f46c7690b0b016cee8e053be1140a98ecd --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -0,0 +1,1729 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" + "github.com/gin-gonic/gin" +) + +const ( + oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" + oauthPendingBrowserCookieName = "oauth_pending_browser_session" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth" + oauthPendingSessionCookieName = "oauth_pending_session" + oauthPendingCookieMaxAgeSec = 10 * 60 + oauthPendingChoiceStep = "choose_account_action_required" + + oauthCompletionResponseKey = "completion_response" +) + +var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error + +type oauthPendingSessionPayload struct { + Intent string + Identity service.PendingAuthIdentityKey + TargetUserID *int64 + ResolvedEmail string + RedirectTo string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + CompletionResponse map[string]any +} + +type oauthAdoptionDecisionRequest struct { + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type bindPendingOAuthLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type createPendingOAuthAccountRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code,omitempty"` + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type sendPendingOAuthVerifyCodeRequest struct { + Email string `json:"email" binding:"required,email"` + TurnstileToken string `json:"turnstile_token,omitempty"` + PendingAuthToken string `json:"pending_auth_token,omitempty"` + PendingOAuthToken string `json:"pending_oauth_token,omitempty"` +} + +func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { + if h == nil || h.authService == nil || h.authService.EntClient() == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil +} + +func generateOAuthPendingBrowserSession() (string, error) { + return oauth.GenerateState() +} + +func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: encodeCookieValue(sessionKey), + Path: oauthPendingBrowserCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: "", + Path: oauthPendingBrowserCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingBrowserCookieName) +} + +func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: encodeCookieValue(sessionToken), + Path: oauthPendingSessionCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: "", + Path: oauthPendingSessionCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingSessionCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingSessionCookieName) +} + +func redirectToFrontendCallback(c *gin.Context, frontendCallback string) { + u, err := url.Parse(frontendCallback) + if err != nil { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = "" + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error { + svc, err := h.pendingIdentityService() + if err != nil { + return err + } + + session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ + Intent: strings.TrimSpace(payload.Intent), + Identity: payload.Identity, + TargetUserID: payload.TargetUserID, + ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), + RedirectTo: strings.TrimSpace(payload.RedirectTo), + BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), + UpstreamIdentityClaims: payload.UpstreamIdentityClaims, + LocalFlowState: map[string]any{ + oauthCompletionResponseKey: payload.CompletionResponse, + }, + }) + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + + setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c)) + return nil +} + +func readCompletionResponse(session map[string]any) (map[string]any, bool) { + if len(session) == 0 { + return nil, false + } + value, ok := session[oauthCompletionResponseKey] + if !ok { + return nil, false + } + result, ok := value.(map[string]any) + if !ok { + return nil, false + } + return result, true +} + +func clonePendingMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any { + payload, _ := readCompletionResponse(session.LocalFlowState) + merged := clonePendingMap(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := merged["redirect"]; !exists { + merged["redirect"] = session.RedirectTo + } + } + for key, value := range overrides { + if value == nil { + delete(merged, key) + continue + } + merged[key] = value + } + applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims) + return merged +} + +func pendingSessionStringValue(values map[string]any, key string) string { + if len(values) == 0 { + return "" + } + raw, ok := values[key] + if !ok { + return "" + } + value, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(value) +} + +func pendingSessionWantsInvitation(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") +} + +func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool { + if len(payload) == 0 { + return false + } + for _, key := range []string{"access_token", "refresh_token"} { + if value := pendingSessionStringValue(payload, key); value != "" { + return true + } + } + return false +} + +func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error { + if session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if strings.TrimSpace(session.Intent) != oauthIntentLogin { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + payload, _ := readCompletionResponse(session.LocalFlowState) + if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + return nil +} + +func (r oauthAdoptionDecisionRequest) hasDecision() bool { + return r.AdoptDisplayName != nil || r.AdoptAvatar != nil +} + +func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { + var req oauthAdoptionDecisionRequest + if c == nil || c.Request == nil || c.Request.Body == nil { + return req, nil + } + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + return req, nil + } + return req, err + } + return req, nil +} + +func cloneOAuthMetadata(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any { + merged := cloneOAuthMetadata(base) + for key, value := range overlay { + merged[key] = value + } + return merged +} + +func normalizeAdoptedOAuthDisplayName(value string) string { + value = strings.TrimSpace(value) + if len([]rune(value)) > 100 { + value = string([]rune(value)[:100]) + } + return value +} + +func (h *AuthHandler) entClient() *dbent.Client { + if h == nil || h.authService == nil { + return nil + } + return h.authService.EntClient() +} + +func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx) + if err != nil || defaults == nil { + return false + } + return defaults.ForceEmailOnThirdPartySignup +} + +func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + + userEntity, err := client.User.Get(ctx, record.UserID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } +func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } +func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } +func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") } + +func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "linuxdo") +} + +func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") } + +func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "wechat") +} + +func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "") +} + +// SendPendingOAuthVerifyCode sends a verification code for a browser-bound +// pending OAuth account-creation flow. +// POST /api/v1/auth/oauth/pending/send-verify-code +func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { + var req sendPendingOAuthVerifyCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return + } + + _, session, _, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } else if err != nil && !errors.Is(err, service.ErrUserNotFound) { + response.ErrorFrom(c, err) + return + } + + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, SendVerifyCodeResponse{ + Message: "Verification code sent successfully", + Countdown: result.Countdown, + }) +} + +func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + existing, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)). + Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err) + } + if existing != nil && !req.hasDecision() { + return existing, nil + } + if existing == nil && !req.hasDecision() { + return nil, nil + } + + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if existing != nil { + input.AdoptDisplayName = existing.AdoptDisplayName + input.AdoptAvatar = existing.AdoptAvatar + input.IdentityID = existing.IdentityID + } + if req.AdoptDisplayName != nil { + input.AdoptDisplayName = *req.AdoptDisplayName + } + if req.AdoptAvatar != nil { + input.AdoptAvatar = *req.AdoptAvatar + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func (h *AuthHandler) ensurePendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req) + if err != nil { + return nil, err + } + if decision != nil { + return decision, nil + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + }) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func updatePendingOAuthSessionProgress( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + intent string, + resolvedEmail string, + targetUserID *int64, + completionResponse map[string]any, +) (*dbent.PendingAuthSession, error) { + if client == nil || session == nil { + return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + + localFlowState := clonePendingMap(session.LocalFlowState) + localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse) + + update := client.PendingAuthSession.UpdateOneID(session.ID). + SetIntent(strings.TrimSpace(intent)). + SetResolvedEmail(strings.TrimSpace(resolvedEmail)). + SetLocalFlowState(localFlowState) + if targetUserID != nil && *targetUserID > 0 { + update = update.SetTargetUserID(*targetUserID) + } else { + update = update.ClearTargetUserID() + } + return update.Save(ctx) +} + +func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { + if session == nil { + return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return *session.TargetUserID, nil + } + email := strings.TrimSpace(session.ResolvedEmail) + if email == "" { + return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") + } + + userEntity, err := findUserByNormalizedEmail(ctx, client, email) + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") + } + return 0, err + } + return userEntity.ID, nil +} + +func userNormalizedEmailPredicate(email string) predicate.User { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.P(func(b *entsql.Builder) { + b.WriteString("LOWER(TRIM("). + Ident(s.C(dbuser.FieldEmail)). + WriteString(")) = "). + Arg(normalized) + })) + }) +} + +func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) { + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + matches, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) + if err != nil { + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } + return matches[0], nil +} + +func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { + if session == nil { + return nil + } + switch strings.TrimSpace(session.ProviderType) { + case "oidc": + issuer := strings.TrimSpace(session.ProviderKey) + if issuer == "" { + issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + } + if issuer == "" { + return nil + } + return &issuer + default: + issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + if issuer == "" { + return nil + } + return &issuer + } +} + +func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") { + return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID) + } + + client := tx.Client() + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil { + if identity.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, identity.UserID) + if err != nil { + return nil, err + } + if activeOwner != nil { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return client.AuthIdentity.UpdateOneID(identity.ID). + SetUserID(userID). + Save(ctx) + } + return identity, nil + } + + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(strings.TrimSpace(session.ProviderType)). + SetProviderKey(strings.TrimSpace(session.ProviderKey)). + SetProviderSubject(strings.TrimSpace(session.ProviderSubject)). + SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + return create.Save(ctx) +} + +func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(providerKey) + channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) + channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) + channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) + metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) + + identityRecords, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + All(ctx) + if err != nil { + return nil, err + } + identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey) + if err != nil { + return nil, err + } + + var legacyOpenIDIdentity *dbent.AuthIdentity + if channelSubject != "" && channelSubject != providerSubject { + legacyOpenIDRecords, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(channelSubject), + ). + All(ctx) + if err != nil { + return nil, err + } + legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey) + if err != nil { + return nil, err + } + } + + switch { + case identity != nil: + update := client.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if identity.UserID != userID { + update = update.SetUserID(userID) + } + if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey { + update = update.SetProviderKey(providerKey) + } + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + case legacyOpenIDIdentity != nil: + update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + default: + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, err + } + } + + if channel == "" || channelAppID == "" || channelSubject == "" { + return identity, nil + } + + channelRecords, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + All(ctx) + if err != nil { + return nil, err + } + channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey) + if err != nil { + return nil, err + } + + channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) + if channelRecord == nil { + if _, err := client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channel). + SetChannelAppID(channelAppID). + SetChannelSubject(channelSubject). + SetMetadata(channelMetadata). + Save(ctx); err != nil { + return nil, err + } + return identity, nil + } + + updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + SetIdentityID(identity.ID). + SetMetadata(channelMetadata) + if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey { + updateChannel = updateChannel.SetProviderKey(providerKey) + } + _, err = updateChannel.Save(ctx) + if err != nil { + return nil, err + } + return identity, nil +} + +func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) { + var preferred *dbent.AuthIdentity + var fallback *dbent.AuthIdentity + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, record.UserID) + if err != nil { + return nil, false, err + } + if activeOwner != nil { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) { + var preferred *dbent.AuthIdentityChannel + var fallback *dbent.AuthIdentityChannel + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID) + if err != nil { + return nil, false, err + } + if activeOwner != nil { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) { + if client == nil || userID <= 0 { + return nil, nil + } + userEntity, err := client.User.Get(ctx, userID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + return userEntity, nil +} + +func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { + if channel == nil { + return map[string]any{} + } + return cloneOAuthMetadata(channel.Metadata) +} + +func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { + if session == nil || decision == nil { + return false + } + switch strings.ToLower(strings.TrimSpace(session.Intent)) { + case "bind_current_user", "login", "adopt_existing_user_by_email": + return true + default: + return decision.AdoptDisplayName || decision.AdoptAvatar + } +} + +func shouldSkipAvatarAdoption(err error) bool { + return errors.Is(err, service.ErrAvatarInvalid) || + errors.Is(err, service.ErrAvatarTooLarge) || + errors.Is(err, service.ErrAvatarNotImage) +} + +func applyPendingOAuthBinding( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, + forceBind bool, + applyFirstBindDefaults bool, +) error { + if client == nil || session == nil { + return nil + } + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults) + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil { + return err + } + return tx.Commit() +} + +func applyPendingOAuthBindingTx( + ctx context.Context, + tx *dbent.Tx, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, + forceBind bool, + applyFirstBindDefaults bool, +) error { + if tx == nil || session == nil { + return nil + } + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { + return nil + } + + targetUserID := int64(0) + if overrideUserID != nil && *overrideUserID > 0 { + targetUserID = *overrideUserID + } else { + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session) + if err != nil { + return err + } + targetUserID = resolvedUserID + } + + adoptedDisplayName := "" + if decision != nil && decision.AdoptDisplayName { + adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) + } + adoptedAvatarURL := "" + if decision != nil && decision.AdoptAvatar { + adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") + } + shouldAdoptAvatar := false + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { + if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil { + shouldAdoptAvatar = true + } else if !shouldSkipAvatarAdoption(err) { + return err + } + } + + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { + if err := tx.Client().User.UpdateOneID(targetUserID). + SetUsername(adoptedDisplayName). + Exec(ctx); err != nil { + return err + } + } + + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + if err != nil { + return err + } + + metadata := cloneOAuthMetadata(identity.Metadata) + for key, value := range session.UpstreamIdentityClaims { + metadata[key] = value + } + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { + metadata["display_name"] = adoptedDisplayName + } + if shouldAdoptAvatar { + metadata["avatar_url"] = adoptedAvatarURL + } + + updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) + } + if _, err := updateIdentity.Save(ctx); err != nil { + return err + } + + if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { + if _, err := tx.Client().IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(identity.ID), + identityadoptiondecision.IDNEQ(decision.ID), + ). + ClearIdentityID(). + Save(ctx); err != nil { + return err + } + if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). + SetIdentityID(identity.ID). + Save(ctx); err != nil { + return err + } + } + + if applyFirstBindDefaults && authService != nil { + if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil { + return err + } + } + + if shouldAdoptAvatar && userService != nil { + if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil { + return err + } + } + + return nil +} + +func consumePendingOAuthBrowserSessionTx( + ctx context.Context, + tx *dbent.Tx, + session *dbent.PendingAuthSession, +) error { + if tx == nil || session == nil { + return service.ErrPendingAuthSessionNotFound + } + + storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID) + if err != nil { + if dbent.IsNotFound(err) { + return service.ErrPendingAuthSessionNotFound + } + return err + } + + now := time.Now().UTC() + if storedSession.ConsumedAt != nil { + return service.ErrPendingAuthSessionConsumed + } + if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) { + return service.ErrPendingAuthSessionExpired + } + if strings.TrimSpace(storedSession.BrowserSessionKey) != "" && + strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return service.ErrPendingAuthBrowserMismatch + } + + if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx); err != nil { + return err + } + + return nil +} + +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + return applyPendingOAuthBinding( + ctx, + client, + authService, + userService, + session, + decision, + overrideUserID, + false, + strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"), + ) +} + +func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { + if len(payload) == 0 || len(upstream) == 0 { + return + } + + displayName := pendingSessionStringValue(upstream, "suggested_display_name") + avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url") + + if displayName != "" { + if _, exists := payload["suggested_display_name"]; !exists { + payload["suggested_display_name"] = displayName + } + } + if avatarURL != "" { + if _, exists := payload["suggested_avatar_url"]; !exists { + payload["suggested_avatar_url"] = avatarURL + } + } + if displayName != "" || avatarURL != "" { + payload["adoption_required"] = true + } +} + +func pendingOAuthIdentityExistsForUser( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + userID int64, +) (bool, error) { + if client == nil || session == nil || userID <= 0 { + return false, nil + } + + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + if providerType == "" || providerSubject == "" { + return false, nil + } + + query := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderSubjectEQ(providerSubject), + authidentity.UserIDEQ(userID), + ) + if strings.EqualFold(providerType, "wechat") { + query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...)) + } else if providerKey != "" { + query = query.Where(authidentity.ProviderKeyEQ(providerKey)) + } + + count, err := query.Count(ctx) + if err != nil { + return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return count > 0, nil +} + +func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt( + ctx context.Context, + session *dbent.PendingAuthSession, + payload map[string]any, +) (bool, error) { + if session == nil || len(payload) == 0 { + return false, nil + } + if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { + return false, nil + } + if !pendingOAuthCompletionIncludesTokenPayload(payload) { + return false, nil + } + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + return false, nil + } + if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" && + pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" { + return false, nil + } + + return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID) +} + +func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + return svc, session, clearCookies, nil +} + +func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { + completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) + payload := gin.H{ + "auth_result": "pending_session", + "provider": strings.TrimSpace(session.ProviderType), + "intent": strings.TrimSpace(session.Intent), + } + for key, value := range completionResponse { + payload[key] = value + } + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + payload["email"] = email + } + return payload +} + +func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any { + normalized := clonePendingMap(payload) + step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) + switch step { + case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": + normalized["step"] = oauthPendingChoiceStep + } + if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) { + normalized["adoption_required"] = true + } + if _, exists := normalized["adoption_required"]; !exists { + if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields { + normalized["adoption_required"] = true + } + } + return normalized +} + +func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any { + response := mergePendingCompletionResponse(session, map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "force_email_on_signup": true, + "email_binding_required": true, + "existing_account_bindable": true, + }) + if email = strings.TrimSpace(email); email != "" { + response["email"] = email + response["resolved_email"] = email + } + return response +} + +func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState( + c *gin.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + email string, +) (*dbent.PendingAuthSession, error) { + completionResponse := pendingOAuthChoiceCompletionResponse(session, email) + session, err := updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + strings.TrimSpace(session.Intent), + email, + nil, + completionResponse, + ) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + return session, nil +} + +func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { + var req bindPendingOAuthLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID { + response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { + tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession( + c.Request.Context(), + user.ID, + user.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + if err != nil { + response.InternalError(c, "Failed to create 2FA session") + return + } + response.Success(c, TotpLoginResponse{ + Requires2FA: true, + TempToken: tempToken, + UserEmailMasked: service.MaskEmail(user.Email), + }) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + response.InternalError(c, "Failed to generate token pair") + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func respondPendingOAuthBindingApplyError(c *gin.Context, err error) { + if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError { + response.ErrorFrom(c, err) + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) +} + +func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) { + var req createPendingOAuthAccountRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email) + if err != nil { + switch { + case errors.Is(err, service.ErrUserNotFound): + existingUser = nil + case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError: + response.ErrorFrom(c, err) + return + default: + response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) + return + } + } + if existingUser != nil { + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( + c.Request.Context(), + email, + req.Password, + strings.TrimSpace(req.VerifyCode), + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + if errors.Is(err, service.ErrEmailExists) { + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + response.ErrorFrom(c, err) + return + } + + rollbackCreatedUser := func(originalErr error) bool { + if user == nil || user.ID <= 0 { + return false + } + if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation( + c.Request.Context(), + user.ID, + strings.TrimSpace(req.InvitationCode), + ); rollbackErr != nil { + response.ErrorFrom(c, infraerrors.InternalServer( + "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED", + "failed to rollback pending oauth account creation", + ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr))) + return true + } + user = nil + return false + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, err) + return + } + + tx, err := client.Tx(c.Request.Context()) + if err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(c.Request.Context(), tx) + + if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + respondPendingOAuthBindingApplyError(c, err) + return + } + + if err := h.authService.FinalizeOAuthEmailAccount( + txCtx, + user, + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, err) + return + } + + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + clearCookies() + response.ErrorFrom(c, err) + return + } + + if pendingOAuthCreateAccountPreCommitHook != nil { + if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + respondPendingOAuthBindingApplyError(c, err) + return + } + } + + if err := tx.Commit(); err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. +// POST /api/v1/auth/oauth/pending/exchange +func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c) + if err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + payload, ok := readCompletionResponse(session.LocalFlowState) + if !ok { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + payload = normalizePendingOAuthCompletionResponse(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := payload["redirect"]; !exists { + payload["redirect"] = session.RedirectTo + } + } + applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if skipAdoptionPrompt { + delete(payload, "adoption_required") + } + if pendingOAuthCompletionIncludesTokenPayload(payload) { + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + } + + if pendingSessionWantsInvitation(payload) { + if adoptionDecision.hasDecision() { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + _ = decision + } + response.Success(c, payload) + return + } + if !adoptionDecision.hasDecision() { + adoptionRequired, _ := payload["adoption_required"].(bool) + if adoptionRequired { + response.Success(c, payload) + return + } + } + + decisionReq := adoptionDecision + if !decisionReq.hasDecision() { + adoptDisplayName := false + adoptAvatar := false + decisionReq = oauthAdoptionDecisionRequest{ + AdoptDisplayName: &adoptDisplayName, + AdoptAvatar: &adoptAvatar, + } + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + + if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + response.Success(c, payload) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8940e37d266f4336223fe4b017be2b590a053290 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -0,0 +1,2881 @@ +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { + payload := map[string]any{ + "access_token": "token", + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Alice", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) { + payload := map[string]any{ + "suggested_display_name": "Existing", + "adoption_required": false, + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Existing", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false) + + cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, cookie) + require.Equal(t, "/api/v1/auth/oauth", cookie.Path) +} + +func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("linuxdo-123@linuxdo-connect.invalid"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "https://cdn.example/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Alice Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err = client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "Alice Example", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.NotNil(t, avatar) + require.Equal(t, "remote_url", avatar.StorageProvider) + require.Equal(t, "https://cdn.example/alice.png", avatar.URL) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("invalid-avatar@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-invalid-avatar-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invalid-avatar-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-invalid-avatar-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "/avatars/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invalid-avatar-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + _, hasAdoptedAvatar := identity.Metadata["avatar_url"] + require.False(t, hasAdoptedAvatar) + + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.Nil(t, avatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-pending-session-token"). + SetIntent("bind_current_user"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("bind-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Bound Example", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/settings/profile", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Bound Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"]) + _, hasDisplayName := identity.Metadata["display_name"] + require.False(t, hasDisplayName) + _, hasAvatarURL := identity.Metadata["avatar_url"] + require.False(t, hasAvatarURL) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("bind-conflict-target@example.com"). + SetUsername("target-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerUser, err := client.User.Create(). + SetEmail("bind-conflict-owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + existingIdentity, err := client.AuthIdentity.Create(). + SetUserID(ownerUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetMetadata(map[string]any{"username": "owner-user"}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-conflict-session-token"). + SetIntent("bind_current_user"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetTargetUserID(targetUser.ID). + SetResolvedEmail(targetUser.Email). + SetBrowserSessionKey("bind-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Conflict Example", + "suggested_avatar_url": "https://cdn.example/conflict.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"]) + + identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID) + require.NoError(t, err) + require.Equal(t, ownerUser.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-false@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-false-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-false-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-false-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Login Example", + "suggested_avatar_url": "https://cdn.example/login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("login-false-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-reassign@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + existingIdentity, err := client.AuthIdentity.Create(). + SetUserID(userEntity.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + previousSession, err := client.PendingAuthSession.Create(). + SetSessionToken("login-reassign-previous-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-reassign-previous-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "previous-access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previousDecision, err := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(previousSession.ID). + SetIdentityID(existingIdentity.ID). + SetAdoptDisplayName(true). + SetAdoptAvatar(true). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-reassign-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-reassign-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Login Reassign", + "suggested_avatar_url": "https://cdn.example/login-reassign.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(session.ID). + SetAdoptDisplayName(false). + SetAdoptAvatar(false). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-reassign-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + reloadedPrevious, err := client.IdentityAdoptionDecision.Get(ctx, previousDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedPrevious.IdentityID) + + currentDecision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, currentDecision.IdentityID) + require.Equal(t, existingIdentity.ID, *currentDecision.IdentityID) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-nodecision@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-nodecision-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-nodecision-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-nodecision-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "login-nodecision-user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("login-nodecision-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("existing-login@example.com"). + SetUsername("existing-login-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(userEntity.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetMetadata(map[string]any{ + "username": "existing-login-user", + }). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-login-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("existing-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Existing Login Example", + "suggested_avatar_url": "https://cdn.example/existing-login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + payload := decodeJSONResponseData(t, recorder) + require.Equal(t, "access-token", payload["access_token"]) + require.Equal(t, "refresh-token", payload["refresh_token"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) + require.NotContains(t, payload, "adoption_required") + + decisionCount, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, decisionCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("blocked@example.com"). + SetUsername("blocked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("blocked-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("blocked-subject-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("blocked-backend-mode-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("invitation-required-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invitation-123"). + SetBrowserSessionKey("invitation-required-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Invite Example", + "suggested_avatar_url": "https://cdn.example/invite.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "error": "invitation_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, "invitation_required", data["error"]) + require.Equal(t, true, data["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invitation-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810") + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-123"). + SetBrowserSessionKey("create-account-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Fresh OIDC User", + "suggested_avatar_url": "https://cdn.example/fresh.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, service.StatusActive, createdUser.Status) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-create-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, createdUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + _, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-123"). + SetBrowserSessionKey("existing-email-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, oauthIntentLogin, payload["intent"]) + require.Equal(t, "oidc", payload["provider"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) + require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.Nil(t, storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) + require.Nil(t, storedSession.ConsumedAt) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-existing-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) +} + +func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + _, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-normalized-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-normalized-123"). + SetBrowserSessionKey("existing-email-normalized-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, oauthIntentLogin, payload["intent"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + +func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + _, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-send-code-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-send-code-123"). + SetBrowserSessionKey("existing-email-send-code-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "email_required", + }, + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")}) + ginCtx.Request = req + + handler.SendPendingOAuthVerifyCode(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.Nil(t, storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + +func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-backend-mode-123"). + SetBrowserSessionKey("create-account-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") + ctx := context.Background() + + conflictOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(conflictOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetMetadata(map[string]any{ + "username": "owner-user", + }). + Save(ctx) + require.NoError(t, err) + + invitation, err := client.RedeemCode.Create(). + SetCode("INVITE123"). + SetType(service.RedeemTypeInvitation). + SetStatus(service.StatusUnused). + SetValue(0). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-conflict-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetBrowserSessionKey("create-account-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusConflict, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID) + require.NoError(t, err) + require.Equal(t, service.StatusUnused, storedInvitation.Status) + require.Nil(t, storedInvitation.UsedBy) + require.Nil(t, storedInvitation.UsedAt) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + userRepoOptions: oauthPendingFlowUserRepoOptions{ + rejectDeleteWhileAuthIdentityExists: true, + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-finalize-failure-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-finalize-failure-123"). + SetBrowserSessionKey("create-account-finalize-failure-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error { + return errors.New("forced post-bind failure") + } + t.Cleanup(func() { + pendingOAuthCreateAccountPreCommitHook = nil + }) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-backend-mode-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-backend-mode-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-invalid-password-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-invalid-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-invalid-password-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "INVALID_CREDENTIALS", payload["reason"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + oldOwnerHash, err := handler.authService.HashPassword("old-secret") + require.NoError(t, err) + oldOwner, err := client.User.Create(). + SetEmail("old-owner@example.com"). + SetUsername("old-owner"). + SetPasswordHash(oldOwnerHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(oldOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-soft-deleted-123"). + SetMetadata(map[string]any{"username": "old-owner"}). + Save(ctx) + require.NoError(t, err) + + _, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx) + require.NoError(t, err) + + newOwnerHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + newOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(newOwnerHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-soft-deleted-owner-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-soft-deleted-123"). + SetTargetUserID(newOwner.ID). + SetResolvedEmail(newOwner.Email). + SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Recovered OIDC User", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err = client.AuthIdentity.Get(ctx, identity.ID) + require.NoError(t, err) + require.Equal(t, newOwner.ID, identity.UserID) +} + +func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) { + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3", + service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + firstSession, err := client.PendingAuthSession.Create(). + SetSessionToken("first-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-first-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("first-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + firstRecorder := httptest.NewRecorder() + firstGinCtx, _ := gin.CreateTestContext(firstRecorder) + firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody) + firstReq.Header.Set("Content-Type", "application/json") + firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)}) + firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")}) + firstGinCtx.Request = firstReq + + handler.BindOIDCOAuthLogin(firstGinCtx) + + require.Equal(t, http.StatusOK, firstRecorder.Code) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID) + require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID) + require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + + secondSession, err := client.PendingAuthSession.Create(). + SetSessionToken("second-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-second-456"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("second-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Second OIDC User", + "suggested_avatar_url": "https://cdn.example/second.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + secondRecorder := httptest.NewRecorder() + secondGinCtx, _ := gin.CreateTestContext(secondRecorder) + secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)}) + secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")}) + secondGinCtx.Request = secondReq + + handler.BindOIDCOAuthLogin(secondGinCtx) + + require.Equal(t, http.StatusOK, secondRecorder.Code) + + storedUser, err = client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) +} + +func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + _ = handler + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("resolve-target-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-target-123"). + SetResolvedEmail("owner@example.com"). + SetBrowserSessionKey("resolve-target-browser-session-key"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + require.NoError(t, err) + require.Equal(t, existingUser.ID, resolvedUserID) +} + +func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + }, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-2fa-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, true, data["requires_2fa"]) + require.Equal(t, "o***r@example.com", data["user_email_masked"]) + tempToken, ok := data["temp_token"].(string) + require.True(t, ok) + require.NotEmpty(t, tempToken) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.NotNil(t, loginSession) + require.NotNil(t, loginSession.PendingOAuthBind) + require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken) + require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + service.SettingKeyAuthSourceDefaultOIDCBalance: "8", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2", + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(4). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-2fa-pending-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-login-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(session.ID). + SetAdoptDisplayName(false). + SetAdoptAvatar(false). + Save(ctx) + require.NoError(t, err) + + tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession( + ctx, + existingUser.ID, + existingUser.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + require.NoError(t, err) + + code, err := totp.GenerateCode(secret, time.Now().UTC()) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)}) + ginCtx.Request = req + + handler.Login2FA(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + payload := decodeJSONResponseData(t, recorder) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-login-2fa-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.Nil(t, loginSession) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 9.5, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + require.Empty(t, defaultSubAssigner.calls) +} + +func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) +} + +func newOAuthPendingFlowTestHandlerWithEmailVerification( + t *testing.T, + invitationEnabled bool, + email string, + code string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + cache := &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + email: { + Code: code, + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + } + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) +} + +func newOAuthPendingFlowTestHandlerWithOptions( + t *testing.T, + invitationEnabled bool, + emailVerifyEnabled bool, + emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + invitationEnabled: invitationEnabled, + emailVerifyEnabled: emailVerifyEnabled, + emailCache: emailCache, + }) +} + +type oauthPendingFlowTestHandlerOptions struct { + invitationEnabled bool + emailVerifyEnabled bool + emailCache service.EmailCache + settingValues map[string]string + defaultSubAssigner service.DefaultSubscriptionAssigner + totpCache service.TotpCache + totpEncryptor service.SecretEncryptor + userRepoOptions oauthPendingFlowUserRepoOptions +} + +func newOAuthPendingFlowTestHandlerWithDependencies( + t *testing.T, + options oauthPendingFlowTestHandlerOptions, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_avatars ( + user_id INTEGER PRIMARY KEY, + storage_provider TEXT NOT NULL, + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT '', + byte_size INTEGER NOT NULL DEFAULT 0, + sha256 TEXT NOT NULL DEFAULT '', + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingValues := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + } + for key, value := range options.settingValues { + settingValues[key] = value + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) + userRepo := &oauthPendingFlowUserRepo{ + client: client, + options: options.userRepoOptions, + } + redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client} + var emailService *service.EmailService + if options.emailCache != nil { + emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + }, + }, options.emailCache) + } + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &oauthPendingFlowRefreshTokenCacheStub{}, + cfg, + settingSvc, + emailService, + nil, + nil, + nil, + options.defaultSubAssigner, + ) + userSvc := service.NewUserService(userRepo, nil, nil, nil) + var totpSvc *service.TotpService + if options.totpCache != nil || options.totpEncryptor != nil { + totpCache := options.totpCache + if totpCache == nil { + totpCache = &oauthPendingFlowTotpCacheStub{} + } + totpEncryptor := options.totpEncryptor + if totpEncryptor == nil { + totpEncryptor = oauthPendingFlowTotpEncryptorStub{} + } + totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil) + } + + return &AuthHandler{ + authService: authSvc, + userService: userSvc, + settingSvc: settingSvc, + totpService: totpSvc, + }, client +} + +func boolSettingValue(v bool) string { + if v { + return "true" + } + return "false" +} + +func boolPtr(v bool) *bool { + return &v +} + +type oauthPendingFlowSettingRepoStub struct { + values map[string]string +} + +func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type oauthPendingFlowRefreshTokenCacheStub struct{} + +type oauthPendingFlowEmailCacheStub struct { + verificationCodes map[string]*service.VerificationCodeData +} + +func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { + if s == nil || s.verificationCodes == nil { + return nil, nil + } + return s.verificationCodes[email], nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { + if s.verificationCodes == nil { + s.verificationCodes = map[string]*service.VerificationCodeData{} + } + s.verificationCodes[email] = data + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { + delete(s.verificationCodes, email) + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +type oauthPendingFlowRedeemCodeRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error { + panic("unexpected Create call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrRedeemCodeNotFound + } + return nil, err + } + notes := "" + if entity.Notes != nil { + notes = *entity.Notes + } + return &service.RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: notes, + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error { + if code == nil { + return nil + } + update := r.client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). + SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err +} + +func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error { + affected, err := r.client.RedeemCode.Update(). + Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). + SetStatus(service.StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return service.ErrRedeemCodeUsed + } + return nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var envelope struct { + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) + return envelope.Data +} + +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + return payload +} + +type oauthPendingFlowAvatarRecord struct { + StorageProvider string + URL string +} + +func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + if !rows.Next() { + require.NoError(t, rows.Err()) + return nil + } + + var record oauthPendingFlowAvatarRecord + require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL)) + require.NoError(t, rows.Err()) + return &record +} + +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, + []any{userID, providerType, grantReason}, + &rows, + ) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next()) + var count int + require.NoError(t, rows.Scan(&count)) + require.False(t, rows.Next()) + return count +} + +type oauthPendingFlowUserRepo struct { + client *dbent.Client + options oauthPendingFlowUserRepoOptions +} + +type oauthPendingFlowUserRepoOptions struct { + rejectDeleteWhileAuthIdentityExists bool +} + +func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { + entity, err := r.client.User.Create(). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.ID = entity.ID + user.CreatedAt = entity.CreatedAt + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + entity, err := r.client.User.Get(ctx, id) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { + entity, err := r.client.User.UpdateOneID(user.ID). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + if r.options.rejectDeleteWhileAuthIdentityExists { + count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx) + if err != nil { + return err + } + if count > 0 { + return errors.New("cannot delete user while auth identities still exist") + } + } + return r.client.User.DeleteOneID(id).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var rows entsql.Rows + if err := driver.Query( + ctx, + `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ); err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + if err := driver.Exec( + ctx, + `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) +ON CONFLICT(user_id) DO UPDATE SET + storage_provider = excluded.storage_provider, + storage_key = excluded.storage_key, + url = excluded.url, + content_type = excluded.content_type, + byte_size = excluded.byte_size, + sha256 = excluded.sha256, + updated_at = CURRENT_TIMESTAMP`, + []any{ + userID, + input.StorageProvider, + input.StorageKey, + input.URL, + input.ContentType, + input.ByteSize, + input.SHA256, + }, + &result, + ); err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} + +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result) +} + +func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { + panic("unexpected DeductBalance call") +} + +func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) + return count > 0, err +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := r.client.AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). + All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: identity.ProviderType, + ProviderKey: identity.ProviderKey, + ProviderSubject: identity.ProviderSubject, + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: identity.Metadata, + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + return records, nil +} + +func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + update := r.client.User.UpdateOneID(userID) + if encryptedSecret == nil { + update = update.ClearTotpSecretEncrypted() + } else { + update = update.SetTotpSecretEncrypted(*encryptedSecret) + } + return update.Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(true). + SetTotpEnabledAt(time.Now().UTC()). + Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(false). + ClearTotpSecretEncrypted(). + ClearTotpEnabledAt(). + Exec(ctx) +} + +func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { + if entity == nil { + return nil + } + return &service.User{ + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + TotpSecretEncrypted: entity.TotpSecretEncrypted, + TotpEnabled: entity.TotpEnabled, + TotpEnabledAt: entity.TotpEnabledAt, + TotalRecharged: entity.TotalRecharged, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} + +type oauthPendingFlowDefaultSubAssignerStub struct { + calls []service.AssignSubscriptionInput +} + +func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + return nil, false, nil +} + +type oauthPendingFlowTotpCacheStub struct { + setupSessions map[int64]*service.TotpSetupSession + loginSessions map[string]*service.TotpLoginSession + verifyAttempts map[int64]int +} + +func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) { + if s == nil || s.setupSessions == nil { + return nil, nil + } + return s.setupSessions[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error { + if s.setupSessions == nil { + s.setupSessions = map[int64]*service.TotpSetupSession{} + } + s.setupSessions[userID] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error { + delete(s.setupSessions, userID) + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) { + if s == nil || s.loginSessions == nil { + return nil, nil + } + return s.loginSessions[tempToken], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error { + if s.loginSessions == nil { + s.loginSessions = map[string]*service.TotpLoginSession{} + } + s.loginSessions[tempToken] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error { + delete(s.loginSessions, tempToken) + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s.verifyAttempts == nil { + s.verifyAttempts = map[int64]int{} + } + s.verifyAttempts[userID]++ + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s == nil || s.verifyAttempts == nil { + return 0, nil + } + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error { + delete(s.verifyAttempts, userID) + return nil +} + +type oauthPendingFlowTotpEncryptorStub struct{} + +func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) { + return plaintext, nil +} + +func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) { + return ciphertext, nil +} diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8eb87dbb4c58d5080d0fe7ed7d06883f75c79b8f --- /dev/null +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -0,0 +1,39 @@ +package handler + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string { + t.Helper() + value, err := buildOAuthBindUserCookieValue(userID, secret) + require.NoError(t, err) + return value +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + decoded, err := decodeCookieValue(value) + require.NoError(t, err) + return decoded +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 9d24df88ab1a08a2a436e120e7e367c65b66c881..d2042a87b03c40616caa4a5cb69f279ce5386008 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -19,6 +19,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" @@ -32,14 +33,16 @@ import ( ) const ( - oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" - oidcOAuthStateCookieName = "oidc_oauth_state" - oidcOAuthVerifierCookie = "oidc_oauth_verifier" - oidcOAuthRedirectCookie = "oidc_oauth_redirect" - oidcOAuthNonceCookie = "oidc_oauth_nonce" - oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - oidcOAuthDefaultRedirectTo = "/dashboard" - oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" + oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" + oidcOAuthStateCookieName = "oidc_oauth_state" + oidcOAuthVerifierCookie = "oidc_oauth_verifier" + oidcOAuthRedirectCookie = "oidc_oauth_redirect" + oidcOAuthNonceCookie = "oidc_oauth_nonce" + oidcOAuthIntentCookieName = "oidc_oauth_intent" + oidcOAuthBindUserCookieName = "oidc_oauth_bind_user" + oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + oidcOAuthDefaultRedirectTo = "/dashboard" + oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" ) type oidcTokenResponse struct { @@ -87,6 +90,8 @@ type oidcUserInfoClaims struct { Username string Subject string EmailVerified *bool + DisplayName string + AvatarURL string } type oidcJWKSet struct { @@ -127,30 +132,46 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) { redirectTo = oidcOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie) oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie) - - codeChallenge := "" - if cfg.UsePKCE { - verifier, genErr := oauth.GenerateCodeVerifier() - if genErr != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr)) + intent := normalizeOAuthIntent(c.Query("intent")) + oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) return } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie) + oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie) + } else { + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) } + codeChallenge := "" + verifier, genErr := oauth.GenerateCodeVerifier() + if genErr != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie) + nonce := "" - if cfg.ValidateIDToken { - nonce, err = oauth.GenerateState() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err)) - return - } - oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie) + nonce, err = oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err)) + return } + oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -199,6 +220,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName) @@ -212,23 +235,26 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = oidcOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) codeVerifier := "" - if cfg.UsePKCE { - codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return - } + codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return } expectedNonce := "" - if cfg.ValidateIDToken { - expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie) - if expectedNonce == "" { - redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "") - return - } + expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie) + if expectedNonce == "" { + redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "") + return } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -258,7 +284,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } - if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" { + if strings.TrimSpace(tokenResp.IDToken) == "" { redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") return } @@ -298,54 +324,235 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { if emailVerified == nil { emailVerified = idClaims.EmailVerified } - if cfg.RequireEmailVerified { - if emailVerified == nil || !*emailVerified { - redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") - return - } + if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) { + redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "") + return } identityKey := oidcIdentityKey(issuer, subject) - email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey) + compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email)) + email := oidcSyntheticEmailFromIdentityKey(identityKey) username := firstNonEmpty( userInfoClaims.Username, idClaims.PreferredUsername, idClaims.Name, oidcFallbackUsername(subject), ) + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: identityRef, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } - // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) if err != nil { - if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") - return - } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) return } - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + if cfg.RequireEmailVerified { + if emailVerified == nil || !*emailVerified { + redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") + return + } + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOIDCOAuthChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + true, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createOIDCOAuthChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + h.isForceEmailOnThirdPartySignup(c.Request.Context()), + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := findUserByNormalizedEmail(ctx, client, email) + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createOIDCOAuthChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + } + if forceEmailOnSignup && compatEmailUser == nil { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) } type completeOIDCOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating @@ -358,17 +565,75 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -405,9 +670,7 @@ func oidcExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { - form.Set("code_verifier", codeVerifier) - } + form.Set("code_verifier", codeVerifier) r := client.R(). SetContext(ctx). @@ -560,9 +823,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC if verified, ok := getGJSONBool(body, "email_verified"); ok { claims.EmailVerified = &verified } + claims.DisplayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + ) + claims.AvatarURL = firstNonEmpty( + getGJSON(body, "picture"), + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) claims.Email = strings.TrimSpace(claims.Email) claims.Username = strings.TrimSpace(claims.Username) claims.Subject = strings.TrimSpace(claims.Subject) + claims.DisplayName = strings.TrimSpace(claims.DisplayName) + claims.AvatarURL = strings.TrimSpace(claims.AvatarURL) return claims } @@ -592,13 +872,9 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall q.Set("scope", cfg.Scopes) } q.Set("state", state) - if strings.TrimSpace(nonce) != "" { - q.Set("nonce", nonce) - } - if cfg.UsePKCE { - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - } + q.Set("nonce", nonce) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") u.RawQuery = q.Encode() return u.String(), nil @@ -831,14 +1107,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string { return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain } -func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string { - email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail)) - if email != "" { - return email - } - return oidcSyntheticEmailFromIdentityKey(identityKey) -} - func oidcFallbackUsername(subject string) string { subject = strings.TrimSpace(subject) if subject == "" { diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a161aa77cf6faefaebae46f6a570227263484c90..2acca18a8a81e4e99aa10f4369bd13f58866d3d7 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -12,7 +13,15 @@ import ( "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -30,26 +39,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) { require.Contains(t, e1, "@oidc-connect.invalid") } -func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) { - identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a") - - email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey) - require.Equal(t, "user@example.com", email) - - email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey) - require.Equal(t, "idtoken@example.com", email) - - email = oidcSelectLoginEmail("", "", identityKey) - require.Contains(t, email, "@oidc-connect.invalid") - require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email) -} - func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) { cfg := config.OIDCConnectConfig{ AuthorizeURL: "https://issuer.example.com/auth", ClientID: "cid", Scopes: "openid email profile", - UsePKCE: true, } u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback") @@ -106,6 +100,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) { require.Error(t, err) } +func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) { + cfg := config.OIDCConnectConfig{} + + claims := oidcParseUserInfo(`{ + "sub":"subject-1", + "preferred_username":"alice", + "name":"Alice Example", + "picture":"https://cdn.example/avatar.png", + "email":"alice@example.com", + "email_verified":true + }`, cfg) + + require.Equal(t, "subject-1", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Example", claims.DisplayName) + require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL) + require.NotNil(t, claims.EmailVerified) + require.True(t, *claims.EmailVerified) +} + func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) @@ -118,3 +132,589 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { E: e, } } + +func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/oauth/authorize", + TokenURL: "https://issuer.example.com/oauth/token", + UserInfoURL: "https://issuer.example.com/oauth/userinfo", + JWKSURL: "https://issuer.example.com/oauth/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84}) + + handler.OIDCOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "issuer.example.com/oauth/authorize") + require.Contains(t, location, "client_id=oidc-client") + require.Contains(t, location, "nonce=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, oidcOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(84), userID) +} + +func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-login", + PreferredUsername: "oidc_login", + DisplayName: "OIDC Login Display", + AvatarURL: "https://cdn.example/oidc-login.png", + Email: "oidc-login@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-subject-login"). + SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.NotEmpty(t, completion["access_token"]) + require.Nil(t, completion["error"]) +} + +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-compat", + PreferredUsername: "oidc_compat", + DisplayName: "OIDC Compat Display", + AvatarURL: "https://cdn.example/oidc-compat.png", + Email: "legacy@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + +func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-unverified-compat", + PreferredUsername: "oidc_unverified", + DisplayName: "OIDC Unverified Compat Display", + AvatarURL: "https://cdn.example/oidc-unverified.png", + Email: "owner@example.com", + EmailVerified: false, + }) + defer cleanup() + cfg.RequireEmailVerified = true + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + _, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location")) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-invite", + PreferredUsername: "oidc_invite", + DisplayName: "OIDC Invite Display", + AvatarURL: "https://cdn.example/oidc-invite.png", + Email: "oidc-invite@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-bind", + PreferredUsername: "oidc_bind", + DisplayName: "OIDC Bind Display", + AvatarURL: "https://cdn.example/oidc-bind.png", + Email: "oidc-bind@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + +func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-1"). + SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Display", + "suggested_avatar_url": "https://cdn.example/oidc.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "OIDC Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "OIDC Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("oidc-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +type oidcProviderFixture struct { + Subject string + PreferredUsername string + DisplayName string + AvatarURL string + Email string + EmailVerified bool +} + +func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + OIDC: oauthCfg, + } + return handler, client +} + +func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "test-kid" + jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}} + tokenResponse := oidcTokenResponse{ + AccessToken: "oidc-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + userInfoPayload := map[string]any{ + "sub": fixture.Subject, + "preferred_username": fixture.PreferredUsername, + "name": fixture.DisplayName, + "picture": fixture.AvatarURL, + "email": fixture.Email, + "email_verified": fixture.EmailVerified, + } + + var issuer string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, json.NewEncoder(w).Encode(tokenResponse)) + case "/userinfo": + require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload)) + case "/jwks": + require.NoError(t, json.NewEncoder(w).Encode(jwks)) + default: + http.NotFound(w, r) + } + })) + + issuer = server.URL + now := time.Now() + claims := oidcIDTokenClaims{ + Email: fixture.Email, + EmailVerified: boolPtr(fixture.EmailVerified), + PreferredUsername: fixture.PreferredUsername, + Name: fixture.DisplayName, + Nonce: "nonce-" + fixture.Subject, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: fixture.Subject, + Audience: jwt.ClaimStrings{"oidc-client"}, + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)), + ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + tokenResponse.IDToken, err = token.SignedString(privateKey) + require.NoError(t, err) + + cfg := config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "Test OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: issuer, + AuthorizeURL: issuer + "/authorize", + TokenURL: issuer + "/token", + UserInfoURL: issuer + "/userinfo", + JWKSURL: issuer + "/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + } + return cfg, server.Close +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..78f5d7c2aee86f4fc27cb5b72ac86d31bc15f0bb --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -0,0 +1,1334 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat" + wechatOAuthCookieMaxAgeSec = 10 * 60 + wechatOAuthStateCookieName = "wechat_oauth_state" + wechatOAuthRedirectCookieName = "wechat_oauth_redirect" + wechatOAuthIntentCookieName = "wechat_oauth_intent" + wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthBindUserCookieName = "wechat_oauth_bind_user" + wechatOAuthDefaultRedirectTo = "/dashboard" + wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" + wechatOAuthProviderKey = "wechat-main" + wechatOAuthLegacyProviderKey = "wechat" + wechatPaymentOAuthCookiePath = "/api/v1/auth/oauth/wechat/payment" + wechatPaymentOAuthStateName = "wechat_payment_oauth_state" + wechatPaymentOAuthRedirect = "wechat_payment_oauth_redirect" + wechatPaymentOAuthContextName = "wechat_payment_oauth_context" + wechatPaymentOAuthScope = "wechat_payment_oauth_scope" + wechatPaymentOAuthDefaultTo = "/purchase" + wechatPaymentOAuthFrontendCB = "/auth/wechat/payment/callback" + + wechatOAuthIntentLogin = "login" + wechatOAuthIntentBind = "bind_current_user" + wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email" +) + +var ( + wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token" + wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo" +) + +type wechatOAuthConfig struct { + mode string + appID string + appSecret string + authorizeURL string + scope string + redirectURI string + frontendCallback string + openEnabled bool + mpEnabled bool +} + +type wechatOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatOAuthUserInfoResponse struct { + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + HeadImgURL string `json:"headimgurl"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatPaymentOAuthContext struct { + PaymentType string `json:"payment_type"` + Amount string `json:"amount,omitempty"` + OrderType string `json:"order_type,omitempty"` + PlanID int64 `json:"plan_id,omitempty"` +} + +// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived +// browser cookies required by the rebuild pending-auth bridge. +func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + intent := normalizeWeChatOAuthIntent(c.Query("intent")) + secureCookie := isRequestHTTPS(c) + wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie) + } else { + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + } + + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, +// and stores the result in the unified pending-auth flow. +func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { + frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context()) + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + + intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName) + mode, err := readCookieDecoded(c, wechatOAuthModeCookieName) + if err != nil || strings.TrimSpace(mode) == "" { + redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "") + return + } + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error())) + return + } + + unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) + openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) + providerSubject := unionid + if providerSubject == "" { + if cfg.requiresUnionID() { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") + return + } + providerSubject = openid + } + if providerSubject == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") + return + } + + username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) + email := wechatSyntheticEmail(providerSubject) + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": providerSubject, + "openid": openid, + "unionid": unionid, + "mode": cfg.mode, + "channel": cfg.mode, + "channel_app_id": strings.TrimSpace(cfg.appID), + "channel_subject": openid, + "suggested_display_name": strings.TrimSpace(userInfo.Nickname), + "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), + } + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + } + + normalizedIntent := normalizeWeChatOAuthIntent(intent) + if normalizedIntent == wechatOAuthIntentBind { + if err := h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil { + switch infraerrors.Code(err) { + case http.StatusConflict: + redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err)) + case http.StatusUnauthorized, http.StatusForbidden: + redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err)) + default: + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + } + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser == nil { + existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + } + if existingIdentityUser != nil { + if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createWeChatChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + "", + nil, + true, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createWeChatChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + "", + nil, + false, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +// WeChatPaymentOAuthStart starts the WeChat payment OAuth flow. +// GET /api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=/purchase +func (h *AuthHandler) WeChatPaymentOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + paymentType := normalizeWeChatPaymentType(c.Query("payment_type")) + if paymentType == "" { + response.BadRequest(c, "Invalid payment type") + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(c.Query("redirect"))) + if redirectTo == "" { + redirectTo = wechatPaymentOAuthDefaultTo + } + rawContext, err := encodeWeChatPaymentOAuthContext(wechatPaymentOAuthContext{ + PaymentType: paymentType, + Amount: strings.TrimSpace(c.Query("amount")), + OrderType: strings.TrimSpace(c.Query("order_type")), + PlanID: parseWeChatPaymentPlanID(c.Query("plan_id")), + }) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONTEXT_ENCODE_FAILED", "failed to encode oauth context").WithCause(err)) + return + } + + scope := normalizeWeChatPaymentScope(c.Query("scope")) + secureCookie := isRequestHTTPS(c) + wechatPaymentSetCookie(c, wechatPaymentOAuthStateName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthRedirect, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthContextName, encodeCookieValue(rawContext), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthScope, encodeCookieValue(scope), wechatOAuthCookieMaxAgeSec, secureCookie) + + cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c) + cfg.scope = scope + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatPaymentOAuthCallback exchanges a payment OAuth code for an OpenID and +// forwards the browser back to the frontend callback route. +func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) { + frontendCallback := wechatPaymentOAuthFrontendCB + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatPaymentOAuthStateName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatPaymentOAuthRedirect) + redirectTo = normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(redirectTo)) + if redirectTo == "" { + redirectTo = wechatPaymentOAuthDefaultTo + } + + rawContext, _ := readCookieDecoded(c, wechatPaymentOAuthContextName) + paymentContext, err := decodeWeChatPaymentOAuthContext(rawContext) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_context", "invalid oauth context", "") + return + } + if paymentContext.PaymentType == "" { + paymentContext.PaymentType = payment.TypeWxpay + } + + scope, _ := readCookieDecoded(c, wechatPaymentOAuthScope) + scope = normalizeWeChatPaymentScope(scope) + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c) + tokenResp, err := exchangeWeChatOAuthCode(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", err.Error()) + return + } + + openid := strings.TrimSpace(tokenResp.OpenID) + if openid == "" { + redirectOAuthError(c, frontendCallback, "missing_openid", "missing openid", "") + return + } + if strings.TrimSpace(tokenResp.Scope) != "" { + scope = strings.TrimSpace(tokenResp.Scope) + } + + resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{ + OpenID: openid, + PaymentType: paymentContext.PaymentType, + Amount: paymentContext.Amount, + OrderType: paymentContext.OrderType, + PlanID: paymentContext.PlanID, + RedirectTo: redirectTo, + Scope: scope, + }) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "") + return + } + + fragment := url.Values{} + fragment.Set("wechat_resume_token", resumeToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService { + key, err := payment.ProvideEncryptionKey(h.cfg) + if err != nil { + return service.NewPaymentResumeService(nil) + } + return service.NewPaymentResumeService([]byte(key)) +} + +type completeWeChatOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by +// validating the invitation code and consuming the current pending browser session. +// POST /api/v1/auth/oauth/wechat/complete-registration +func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { + var req completeWeChatOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) createWeChatPendingSession( + c *gin.Context, + intent string, + providerSubject string, + email string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + tokenPair *service.TokenPair, + authErr error, + targetUserID *int64, +) error { + completionResponse := map[string]any{ + "redirect": redirectTo, + } + if authErr != nil { + if errors.Is(authErr, service.ErrOAuthInvitationRequired) { + completionResponse["error"] = "invitation_required" + } else { + return authErr + } + } else if tokenPair != nil { + completionResponse["access_token"] = tokenPair.AccessToken + completionResponse["refresh_token"] = tokenPair.RefreshToken + completionResponse["expires_in"] = tokenPair.ExpiresIn + completionResponse["token_type"] = "Bearer" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: intent, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + }, + TargetUserID: targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) createWeChatChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + } + if forceEmailOnSignup { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) createWeChatBindPendingSession( + c *gin.Context, + cfg wechatOAuthConfig, + providerSubject string, + channelSubject string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName) + if err != nil { + return err + } + if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil { + return err + } + return h.createWeChatPendingSession( + c, + wechatOAuthIntentBind, + providerSubject, + currentUser.Email, + redirectTo, + browserSessionKey, + upstreamClaims, + nil, + nil, + ¤tUser.ID, + ) +} + +func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName) + if err != nil { + return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + userEntity, err := client.User.Get(c.Request.Context(), userID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) ensureWeChatBindOwnership( + ctx context.Context, + userID int64, + providerSubject string, + cfg wechatOAuthConfig, + channelSubject string, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), + authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), + ). + All(ctx) + if err != nil { + return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) + } + for _, identity := range identities { + if identity != nil && identity.UserID != userID { + activeOwner, lookupErr := findActiveUserByID(ctx, client, identity.UserID) + if lookupErr != nil { + return lookupErr + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + } + + channelSubject = strings.TrimSpace(channelSubject) + channelAppID := strings.TrimSpace(cfg.appID) + if channelSubject == "" || channelAppID == "" { + return nil + } + + channels, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), + authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + All(ctx) + if err != nil { + return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) + } + for _, channel := range channels { + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + activeOwner, lookupErr := findActiveUserByID(ctx, client, channel.Edges.Identity.UserID) + if lookupErr != nil { + return lookupErr + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + } + return nil +} + +func (h *AuthHandler) findWeChatUserByLegacyOpenID( + ctx context.Context, + identity service.PendingAuthIdentityKey, + cfg wechatOAuthConfig, + openid string, +) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + providerType := strings.TrimSpace(identity.ProviderType) + providerSubject := strings.TrimSpace(identity.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey) + if providerSubject != "" { + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + WithUser(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { + return user, err + } + } + + openid = strings.TrimSpace(openid) + channel := strings.TrimSpace(cfg.mode) + channelAppID := strings.TrimSpace(cfg.appID) + if openid != "" && channel != "" && channelAppID != "" { + records, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(openid), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if user, err := singleWeChatChannelUser(records); err != nil || user != nil { + return user, err + } + } + + if openid == "" { + return nil, nil + } + + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(openid), + ). + WithUser(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return singleWeChatIdentityUser(records) +} + +func wechatCompatibleProviderKeys(providerKey string) []string { + preferred := strings.TrimSpace(providerKey) + if preferred == "" { + preferred = wechatOAuthProviderKey + } + keys := []string{preferred} + if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) { + keys = append(keys, wechatOAuthLegacyProviderKey) + } + return keys +} + +func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.User + continue + } + if resolved.ID != record.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + return resolved, nil +} + +func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.Identity.Edges.User + continue + } + if resolved.ID != record.Edges.Identity.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + return resolved, nil +} + +func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( + ctx context.Context, + userID int64, + identity service.PendingAuthIdentityKey, + upstreamClaims map[string]any, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + tx, err := client.Tx(ctx) + if err != nil { + return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims), + }, userID) + if err != nil { + return err + } + return tx.Commit() +} + +func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { + mode, err := resolveWeChatOAuthMode(rawMode, c) + if err != nil { + return wechatOAuthConfig{}, err + } + + if h == nil || h.settingSvc == nil { + return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready") + } + + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + settings, err := h.settingSvc.GetAllSettings(ctx) + if err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + + effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err != nil { + return wechatOAuthConfig{}, err + } + if !effective.SupportsMode(mode) { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + + cfg := wechatOAuthConfig{ + mode: mode, + appID: strings.TrimSpace(effective.AppIDForMode(mode)), + appSecret: strings.TrimSpace(effective.AppSecretForMode(mode)), + redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")), + frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB), + scope: effective.ScopeForMode(mode), + openEnabled: effective.OpenEnabled, + mpEnabled: effective.MPEnabled, + } + + switch mode { + case "mp": + cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" + default: + cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" + } + if strings.TrimSpace(cfg.redirectURI) == "" { + return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + + return cfg, nil +} + +func (cfg wechatOAuthConfig) requiresUnionID() bool { + return cfg.openEnabled && cfg.mpEnabled +} + +func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string { + if h != nil && h.settingSvc != nil { + cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" { + return strings.TrimSpace(cfg.FrontendRedirectURL) + } + } + return wechatOAuthDefaultFrontendCB +} + +func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { + mode := strings.ToLower(strings.TrimSpace(rawMode)) + if mode == "" { + if isWeChatBrowserRequest(c) { + return "mp", nil + } + return "open", nil + } + if mode != "open" && mode != "mp" { + return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp") + } + return mode, nil +} + +func isWeChatBrowserRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger") +} + +func normalizeWeChatOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "login": + return wechatOAuthIntentLogin + case "bind", "bind_current_user": + return wechatOAuthIntentBind + case "adopt", "adopt_existing_user_by_email": + return wechatOAuthIntentAdoptEmail + default: + return wechatOAuthIntentLogin + } +} + +func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) { + u, err := url.Parse(cfg.authorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize url: %w", err) + } + query := u.Query() + query.Set("appid", cfg.appID) + query.Set("redirect_uri", cfg.redirectURI) + query.Set("response_type", "code") + query.Set("scope", cfg.scope) + query.Set("state", state) + u.RawQuery = query.Encode() + u.Fragment = "wechat_redirect" + return u.String(), nil +} + +func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string { + callbackPath = strings.TrimSpace(callbackPath) + if callbackPath == "" { + return "" + } + + if raw := strings.TrimSpace(apiBaseURL); raw != "" { + if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" { + basePath := strings.TrimRight(parsed.EscapedPath(), "/") + targetPath := callbackPath + if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") { + targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1") + } else if basePath != "" { + targetPath = basePath + callbackPath + } + return parsed.Scheme + "://" + parsed.Host + targetPath + } + } + + if c == nil || c.Request == nil { + return "" + } + scheme := "http" + if isRequestHTTPS(c) { + scheme = "https" + } + host := strings.TrimSpace(c.Request.Host) + if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + if host == "" { + return "" + } + return scheme + "://" + host + callbackPath +} + +func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) { + tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code) + if err != nil { + return nil, nil, err + } + userInfo, err := fetchWeChatUserInfo(ctx, tokenResp) + if err != nil { + return nil, nil, err + } + return tokenResp, userInfo, nil +} + +func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) { + endpoint, err := url.Parse(wechatOAuthAccessTokenURL) + if err != nil { + return nil, fmt.Errorf("parse wechat access token url: %w", err) + } + + query := endpoint.Query() + query.Set("appid", cfg.appID) + query.Set("secret", cfg.appSecret) + query.Set("code", strings.TrimSpace(code)) + query.Set("grant_type", "authorization_code") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat access token request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat access token: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat access token response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode) + } + + var tokenResp wechatOAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode wechat access token response: %w", err) + } + if tokenResp.ErrCode != 0 { + return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg)) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, fmt.Errorf("wechat access token missing access_token") + } + return &tokenResp, nil +} + +func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) { + if tokenResp == nil { + return nil, fmt.Errorf("wechat token response is nil") + } + + endpoint, err := url.Parse(wechatOAuthUserInfoURL) + if err != nil { + return nil, fmt.Errorf("parse wechat userinfo url: %w", err) + } + query := endpoint.Query() + query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken)) + query.Set("openid", strings.TrimSpace(tokenResp.OpenID)) + query.Set("lang", "zh_CN") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat userinfo request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat userinfo: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat userinfo response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode) + } + + var userInfo wechatOAuthUserInfoResponse + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, fmt.Errorf("decode wechat userinfo response: %w", err) + } + if userInfo.ErrCode != 0 { + return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg)) + } + return &userInfo, nil +} + +func wechatSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain +} + +func wechatFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "wechat_user" + } + return "wechat_" + truncateFragmentValue(subject) +} + +func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func normalizeWeChatPaymentType(raw string) string { + switch strings.TrimSpace(raw) { + case payment.TypeWxpay, payment.TypeWxpayDirect: + return strings.TrimSpace(raw) + default: + return "" + } +} + +func normalizeWeChatPaymentScope(raw string) string { + for _, part := range strings.FieldsFunc(strings.TrimSpace(raw), func(r rune) bool { + return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r' + }) { + switch strings.TrimSpace(part) { + case "snsapi_userinfo": + return "snsapi_userinfo" + case "snsapi_base": + return "snsapi_base" + } + } + return "snsapi_base" +} + +func normalizeWeChatPaymentRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return wechatPaymentOAuthDefaultTo + } + if path == "/payment" { + return "/purchase" + } + if strings.HasPrefix(path, "/payment?") { + return "/purchase" + strings.TrimPrefix(path, "/payment") + } + return path +} + +func (h *AuthHandler) resolveWeChatPaymentOAuthCallbackURL(ctx context.Context, c *gin.Context) string { + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + if settings, err := h.settingSvc.GetAllSettings(ctx); err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + return resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/payment/callback") +} + +func encodeWeChatPaymentOAuthContext(ctx wechatPaymentOAuthContext) (string, error) { + data, err := json.Marshal(ctx) + if err != nil { + return "", err + } + return string(data), nil +} + +func decodeWeChatPaymentOAuthContext(raw string) (wechatPaymentOAuthContext, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return wechatPaymentOAuthContext{}, nil + } + var ctx wechatPaymentOAuthContext + if err := json.Unmarshal([]byte(raw), &ctx); err != nil { + return wechatPaymentOAuthContext{}, err + } + return ctx, nil +} + +func parseWeChatPaymentPlanID(raw string) int64 { + id, _ := strconv.ParseInt(strings.TrimSpace(raw), 10, 64) + return id +} + +func wechatPaymentSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatPaymentOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatPaymentClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatPaymentOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..937daa6d3b68f57b22a284b96604b5148397568f --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -0,0 +1,1184 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-open-app", + service.SettingKeyWeChatConnectAppSecret: "wx-open-secret", + service.SettingKeyWeChatConnectMode: "open", + service.SettingKeyWeChatConnectScopes: "snsapi_login", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "appid=wx-open-app") + require.Contains(t, location, "scope=snsapi_login") + + cookies := recorder.Result().Cookies() + require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) + require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) +} + +func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-shared-app", + service.SettingKeyWeChatConnectAppSecret: "wx-shared-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectOpenEnabled: "true", + service.SettingKeyWeChatConnectMPEnabled: "true", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "connect/qrconnect") + require.Contains(t, location, "scope=snsapi_login") +} + +func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat", session.ProviderType) + require.Equal(t, "wechat-main", session.ProviderKey) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) + require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) +} + +func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Equal(t, "openid-123", session.ProviderSubject) + require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`)) + return + } + http.NotFound(w, r) + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) + defer client.Close() + handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123")) + req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat")) + req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`)) + req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base")) + c.Request = req + + handler.WeChatPaymentOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + parsed, err := url.Parse(location) + require.NoError(t, err) + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect")) + require.NotEmpty(t, fragment.Get("wechat_resume_token")) + require.Empty(t, fragment.Get("openid")) + require.Empty(t, fragment.Get("payment_type")) + require.Empty(t, fragment.Get("amount")) + require.Empty(t, fragment.Get("order_type")) + require.Empty(t, fragment.Get("plan_id")) + + claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token")) + require.NoError(t, err) + require.Equal(t, "openid-123", claims.OpenID) + require.Equal(t, payment.TypeWxpay, claims.PaymentType) + require.Equal(t, "12.5", claims.Amount) + require.Equal(t, payment.OrderTypeSubscription, claims.OrderType) + require.EqualValues(t, 7, claims.PlanID) + require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) +} + +func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { + testCases := []struct { + name string + mode string + appID string + appSecret string + openID string + }{ + { + name: "open", + mode: "open", + appID: "wx-open-app", + appSecret: "wx-open-secret", + openID: "openid-open-123", + }, + { + name: "mp", + mode: "mp", + appID: "wx-mp-app", + appSecret: "wx-mp-secret", + openID: "openid-mp-123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback")) + defer client.Close() + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode)) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) + require.NoError(t, err) + require.Equal(t, wechatOAuthIntentBind, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, currentUser.Email, session.ResolvedEmail) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"]) + require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"]) + require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"]) + + completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completionResponse["redirect"]) + _, hasAccessToken := completionResponse["access_token"] + require.False(t, hasAccessToken) + }) + } +} + +func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerIdentity, err := client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-owner"). + SetMetadata(map[string]any{"unionid": "union-owner"}). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentityChannel.Create(). + SetIdentityID(ownerIdentity.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetChannel("open"). + SetChannelAppID("wx-open-app"). + SetChannelSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, true) + defer client.Close() + + ctx := context.Background() + redeemRepo := repository.NewRedeemCodeRepository(client) + require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ + Code: "invite-1", + Type: service.RedeemTypeInvitation, + Status: service.StatusUnused, + })) + + callbackRecorder := httptest.NewRecorder() + callbackCtx, _ := gin.CreateTestContext(callbackRecorder) + callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + callbackReq.Host = "api.example.com" + callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + callbackCtx.Request = callbackReq + + handler.WeChatOAuthCallback(callbackCtx) + + require.Equal(t, http.StatusFound, callbackRecorder.Code) + require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) + + sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) + + pendingSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"]) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) + completeRecorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(completeRecorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, completeRecorder.Code) + responseData := decodeJSONBody(t, completeRecorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "WeChat Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "union-456", channel.Metadata["unionid"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + openIDIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, openIDIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + +func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-invalid-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("wechat-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + legacyIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, legacyIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + +func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil) +} + +func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string { + return map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: appID, + service.SettingKeyWeChatConnectAppSecret: secret, + service.SettingKeyWeChatConnectMode: mode, + service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode), + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect, + } +} + +func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + userRepo := &oauthPendingFlowUserRepo{client: client} + redeemRepo := repository.NewRedeemCodeRepository(client) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + values := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + } + for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") { + values[key] = value + } + for key, value := range extraSettings { + values[key] = value + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg) + + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &wechatOAuthRefreshTokenCacheStub{}, + cfg, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + cfg: cfg, + }, client +} + +func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + + parsed, err := url.Parse(location) + require.NoError(t, err) + + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + require.Equal(t, errorCode, fragment.Get("error")) + require.Equal(t, errorMessage, fragment.Get("error_message")) +} + +type wechatOAuthSettingRepoStub struct { + values map[string]string +} + +func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type wechatOAuthRefreshTokenCacheStub struct{} + +func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d2ccb8d62f827e9b1a7d92a5d90679aedc31df4f..9780ff79fda0c7225f7a42b60d286f6f09ec2189 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -21,6 +21,7 @@ func UserFromServiceShallow(u *service.User) *User { Concurrency: u.Concurrency, Status: u.Status, AllowedGroups: u.AllowedGroups, + LastActiveAt: u.LastActiveAt, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, BalanceNotifyEnabled: u.BalanceNotifyEnabled, @@ -66,6 +67,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return &AdminUser{ User: *base, Notes: u.Notes, + LastUsedAt: u.LastUsedAt, GroupRates: u.GroupRates, } } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 3659e79be3ed9a0aeab4dc2838ec25957d5b026f..fc6a3f9e0c580c701eddfefcb87bc06214342a9f 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,6 +51,23 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` + WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` + WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"` + WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` + WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"` + WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` + WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"` + WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` + WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` + WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` OIDCConnectClientID string `json:"oidc_connect_client_id"` @@ -127,6 +144,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"` + // Payment configuration PaymentEnabled bool `json:"payment_enabled"` PaymentMinAmount float64 `json:"payment_min_amount"` @@ -167,6 +193,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` @@ -189,6 +216,10 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` SoraClientEnabled bool `json:"sora_client_enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8c1e166f3d69dba16ebd90e2cb7571f679593fd8..c0bce40b37bff09b4dd0238cda2bb349ac9eeb75 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -7,16 +7,17 @@ import ( ) type User struct { - ID int64 `json:"id"` - Email string `json:"email"` - Username string `json:"username"` - Role string `json:"role"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - Status string `json:"status"` - AllowedGroups []int64 `json:"allowed_groups"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + Email string `json:"email"` + Username string `json:"username"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + Status string `json:"status"` + AllowedGroups []int64 `json:"allowed_groups"` + LastActiveAt *time.Time `json:"last_active_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` // 余额不足通知 BalanceNotifyEnabled bool `json:"balance_notify_enabled"` @@ -34,7 +35,8 @@ type User struct { type AdminUser struct { User - Notes string `json:"notes"` + Notes string `json:"notes"` + LastUsedAt *time.Time `json:"last_used_at"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier GroupRates map[int64]float64 `json:"group_rates,omitempty"` diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a17f0ce486549d0243f4f30e7f0cbd5eea0dfbb4 --- /dev/null +++ b/backend/internal/handler/dto/user_mapper_activity_test.go @@ -0,0 +1,33 @@ +package dto + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) { + t.Parallel() + + lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(15 * time.Minute) + lastUsedAt := lastLoginAt.Add(45 * time.Minute) + + out := UserFromServiceAdmin(&service.User{ + ID: 42, + Email: "admin@example.com", + Username: "admin", + Role: service.RoleAdmin, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + }) + + require.NotNil(t, out) + require.NotNil(t, out.LastActiveAt) + require.NotNil(t, out.LastUsedAt) + require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 5319b55d9dddf85a13200e749ee00a15128f1de1..43999a01a4639389c792df007fdfbcc9e0d3a89b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") return } + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_requires_wsv2"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2") + return } setOpsRequestContext(c, reqModel, reqStream, body) @@ -856,7 +861,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, reqLog.Warn("openai.request_validation_failed", zap.String("reason", "function_call_output_missing_call_id"), ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2") return false } if validation.HasItemReferenceForAllCallIDs { @@ -866,7 +871,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, reqLog.Warn("openai.request_validation_failed", zap.String("reason", "function_call_output_missing_item_reference"), ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2") return false } diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index d299fb81e338120a49052581b8dd00ac813c6ab1..8ecee59ae79d47e980d24b83824369c530b98d76 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") } +func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Responses WebSocket v2") + require.Contains(t, w.Body.String(), "previous_response_id") +} + +func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Responses WebSocket v2") + require.NotContains(t, w.Body.String(), "reuse previous_response_id") +} + func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 854dca548bfe19922d494dc6bbf71be0dec81275..16b25355b5799753e321a6fca2cf82f3df8ea63a 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -1,9 +1,14 @@ package handler import ( + "fmt" + "net/http" "strconv" "strings" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -202,10 +207,14 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) { // CreateOrderRequest is the request body for creating a payment order. type CreateOrderRequest struct { - Amount float64 `json:"amount"` - PaymentType string `json:"payment_type" binding:"required"` - OrderType string `json:"order_type"` - PlanID int64 `json:"plan_id"` + Amount float64 `json:"amount"` + PaymentType string `json:"payment_type" binding:"required"` + OpenID string `json:"openid"` + WechatResumeToken string `json:"wechat_resume_token"` + ReturnURL string `json:"return_url"` + PaymentSource string `json:"payment_source"` + OrderType string `json:"order_type"` + PlanID int64 `json:"plan_id"` // IsMobile lets the frontend declare its mobile status directly. When // nil we fall back to User-Agent heuristics (which miss iPadOS / some // embedded browsers that strip the "Mobile" keyword). @@ -225,21 +234,36 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { response.BadRequest(c, "Invalid request: "+err.Error()) return } + if strings.TrimSpace(req.WechatResumeToken) != "" { + claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil { + response.ErrorFrom(c, err) + return + } + } mobile := isMobile(c) if req.IsMobile != nil { mobile = *req.IsMobile } result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ - UserID: subject.UserID, - Amount: req.Amount, - PaymentType: req.PaymentType, - ClientIP: c.ClientIP(), - IsMobile: mobile, - SrcHost: c.Request.Host, - SrcURL: c.Request.Referer(), - OrderType: req.OrderType, - PlanID: req.PlanID, + UserID: subject.UserID, + Amount: req.Amount, + PaymentType: req.PaymentType, + OpenID: req.OpenID, + ClientIP: c.ClientIP(), + IsMobile: mobile, + IsWeChatBrowser: isWeChatBrowser(c), + SrcHost: c.Request.Host, + SrcURL: c.Request.Referer(), + ReturnURL: req.ReturnURL, + PaymentSource: req.PaymentSource, + OrderType: req.OrderType, + PlanID: req.PlanID, }) if err != nil { response.ErrorFrom(c, err) @@ -248,6 +272,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { response.Success(c, result) } +func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error { + if req == nil || claims == nil { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing") + } + openid := strings.TrimSpace(claims.OpenID) + if openid == "" { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid") + } + + paymentType := service.NormalizeVisibleMethod(claims.PaymentType) + if paymentType == "" { + paymentType = payment.TypeWxpay + } + if req.PaymentType != "" { + requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType) + if requestPaymentType != "" && requestPaymentType != paymentType { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch") + } + } + req.PaymentType = paymentType + req.OpenID = openid + + if strings.TrimSpace(claims.Amount) != "" { + amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64) + if err != nil || amount <= 0 { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount)) + } + req.Amount = amount + } + if claims.OrderType != "" { + req.OrderType = claims.OrderType + } + if claims.PlanID > 0 { + req.PlanID = claims.PlanID + } + return nil +} + // GetMyOrders returns the authenticated user's orders. // GET /api/v1/payment/orders/my func (h *PaymentHandler) GetMyOrders(c *gin.Context) { @@ -268,7 +330,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Paginated(c, orders, int64(total), page, pageSize) + response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize) } // GetOrder returns a single order for the authenticated user. @@ -290,7 +352,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, order) + response.Success(c, sanitizePaymentOrderForResponse(order)) } // CancelOrder cancels a pending order for the authenticated user. @@ -362,6 +424,10 @@ type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` } +type ResolveOrderByResumeTokenRequest struct { + ResumeToken string `json:"resume_token" binding:"required"` +} + // VerifyOrder actively queries the upstream payment provider to check // if payment was made, and processes it if so. // POST /api/v1/payment/orders/verify @@ -382,7 +448,7 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, order) + response.Success(c, sanitizePaymentOrderForResponse(order)) } // PublicOrderResult is the limited order info returned by the public verify endpoint. @@ -397,16 +463,32 @@ type PublicOrderResult struct { Status string `json:"status"` } -// VerifyOrderPublic verifies payment status without requiring authentication. -// Returns limited order info (no user details) to prevent information leakage. +var errPaymentPublicOrderVerifyRemoved = infraerrors.New( + http.StatusGone, + "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", + "public payment order verification by out_trade_no has been removed; use resume_token recovery instead", +).WithMetadata(map[string]string{ + "replacement_endpoint": "/api/v1/payment/public/orders/resolve", + "replacement_field": "resume_token", +}) + +// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous +// out_trade_no lookup endpoint and always returns HTTP 410 Gone. // POST /api/v1/payment/public/orders/verify func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { - var req VerifyOrderRequest + response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved) +} + +// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token. +// POST /api/v1/payment/public/orders/resolve +func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) { + var req ResolveOrderByResumeTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } - order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo) + + order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken) if err != nil { response.ErrorFrom(c, err) return @@ -443,3 +525,27 @@ func isMobile(c *gin.Context) bool { } return false } + +func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder { + if len(orders) == 0 { + return orders + } + out := make([]*dbent.PaymentOrder, 0, len(orders)) + for _, order := range orders { + out = append(out, sanitizePaymentOrderForResponse(order)) + } + return out +} + +func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder { + if order == nil { + return nil + } + cloned := *order + cloned.ProviderSnapshot = nil + return &cloned +} + +func isWeChatBrowser(c *gin.Context) bool { + return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger") +} diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go new file mode 100644 index 0000000000000000000000000000000000000000..28da15d98041a4bbf3c811ed8002b037b62fce8c --- /dev/null +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -0,0 +1,114 @@ +//go:build unit + +package handler + +import ( + "bytes" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestApplyWeChatPaymentResumeClaims(t *testing.T) { + t.Parallel() + + req := CreateOrderRequest{ + Amount: 0, + PaymentType: payment.TypeWxpay, + OrderType: payment.OrderTypeBalance, + } + + err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeSubscription, + PlanID: 7, + }) + if err != nil { + t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err) + } + if req.OpenID != "openid-123" { + t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123") + } + if req.Amount != 12.5 { + t.Fatalf("amount = %v, want 12.5", req.Amount) + } + if req.OrderType != payment.OrderTypeSubscription { + t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription) + } + if req.PlanID != 7 { + t.Fatalf("plan_id = %d, want 7", req.PlanID) + } +} + +func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) { + t.Parallel() + + req := CreateOrderRequest{ + PaymentType: payment.TypeAlipay, + } + + err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeBalance, + }) + if err == nil { + t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types") + } +} + +func TestVerifyOrderPublicReturnsGone(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/verify", + bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.VerifyOrderPublic(ctx) + + require.Equal(t, http.StatusGone, recorder.Code) + + var resp response.Response + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusGone, resp.Code) + require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason) + require.Contains(t, resp.Message, "removed") +} diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfebb72830b9957c83f4cf419814ffbfc2a0..c06a5b7e03cf1d8c3c2879bb763778ccf126dbef 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -1,6 +1,8 @@ package handler import ( + "context" + "fmt" "io" "log/slog" "net/http" @@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts). outTradeNo := extractOutTradeNo(rawBody, providerKey) - provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo) + providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo) if err != nil { slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err) + if providerKey == payment.TypeWxpay { + c.String(http.StatusBadRequest, "verify failed") + return + } writeSuccessResponse(c, providerKey) return } @@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) headers[strings.ToLower(k)] = c.GetHeader(k) } - notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers) + resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers) if err != nil { truncatedBody := rawBody if len(truncatedBody) > webhookLogTruncateLen { @@ -103,24 +109,24 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // nil notification means irrelevant event (e.g. Stripe non-payment event); return success. if notification == nil { - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) return } - if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil { - slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err) + if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil { + slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err) c.String(http.StatusInternalServerError, "handle failed") return } - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) } // extractOutTradeNo parses the webhook body to find the out_trade_no. // This allows looking up the correct provider instance before verification. func extractOutTradeNo(rawBody, providerKey string) string { switch providerKey { - case payment.TypeEasyPay: + case payment.TypeEasyPay, payment.TypeAlipay: values, err := url.ParseQuery(rawBody) if err == nil { return values.Get("out_trade_no") @@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string { return "" } +func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) { + var lastErr error + for _, provider := range providers { + if provider == nil { + continue + } + notification, err := provider.VerifyNotification(ctx, rawBody, headers) + if err != nil { + lastErr = err + continue + } + return provider.ProviderKey(), notification, nil + } + if lastErr != nil { + return "", nil, lastErr + } + return "", nil, fmt.Errorf("no webhook provider could verify notification") +} + // wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook. type wxpaySuccessResponse struct { Code string `json:"code"` diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index bdef1766d91108a7ab4656bfc29715667a3a40aa..88221b5c09b3a2ce058af935f70d40ed060b7ba8 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -3,11 +3,14 @@ package handler import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -97,3 +100,104 @@ func TestWebhookConstants(t *testing.T) { assert.Equal(t, 200, webhookLogTruncateLen) }) } + +func TestExtractOutTradeNo(t *testing.T) { + tests := []struct { + name string + providerKey string + rawBody string + want string + }{ + { + name: "easypay query payload", + providerKey: "easypay", + rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS", + want: "sub2_123", + }, + { + name: "alipay query payload", + providerKey: "alipay", + rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456", + want: "sub2_456", + }, + { + name: "unknown provider", + providerKey: "wxpay", + rawBody: "{}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey)) + }) + } +} + +func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) { + firstErr := errors.New("wrong provider") + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: firstErr, + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + notification: &payment.PaymentNotification{ + OrderID: "sub2_42", + TradeNo: "trade-42", + Status: payment.NotificationStatusSuccess, + }, + }, + } + + providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"}) + require.NoError(t, err) + require.Equal(t, payment.TypeWxpay, providerKey) + require.NotNil(t, notification) + require.Equal(t, "sub2_42", notification.OrderID) +} + +func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) { + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed a"), + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed b"), + }, + } + + _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil) + require.Error(t, err) +} + +type webhookHandlerProviderStub struct { + key string + notification *payment.PaymentNotification + verifyErr error +} + +func (p webhookHandlerProviderStub) Name() string { return p.key } +func (p webhookHandlerProviderStub) ProviderKey() string { return p.key } +func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.PaymentType(p.key)} +} +func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + if p.verifyErr != nil { + return nil, p.verifyErr + } + return p.notification, nil +} +func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1717b7a1e9afc666465d0dedcb3d457fe530a1a2..c0f5c28be2d185e03698180a8683f0e80e8e60d9 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { response.Success(c, dto.PublicSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, @@ -56,6 +57,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, + WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, + WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled, + WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go new file mode 100644 index 0000000000000000000000000000000000000000..45d66f8e337ed5c4647518976dcbdbaf157a79a1 --- /dev/null +++ b/backend/internal/handler/setting_handler_public_test.go @@ -0,0 +1,122 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerPublicRepoStub struct { + values map[string]string +} + +func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.ForceEmailOnThirdPartySignup) +} + +func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-mp-app", + service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectOpenEnabled: "true", + service.SettingKeyWeChatConnectMPEnabled: "true", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.WeChatOAuthEnabled) + require.True(t, resp.Data.WeChatOAuthOpenEnabled) + require.True(t, resp.Data.WeChatOAuthMPEnabled) +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 2535ea5e6f60e4da7761a0edcd4915a7cb143e22..3e5ca0807f0926ebc56a0b873f282538038c676b 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -1,6 +1,9 @@ package handler import ( + "context" + "strings" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -12,14 +15,21 @@ import ( // UserHandler handles user-related requests type UserHandler struct { userService *service.UserService + authService *service.AuthService emailService *service.EmailService emailCache service.EmailCache } // NewUserHandler creates a new UserHandler -func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler { +func NewUserHandler( + userService *service.UserService, + authService *service.AuthService, + emailService *service.EmailService, + emailCache service.EmailCache, +) *UserHandler { return &UserHandler{ userService: userService, + authService: authService, emailService: emailService, emailCache: emailCache, } @@ -34,10 +44,33 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type userProfileResponse struct { + dto.User + AvatarURL string `json:"avatar_url,omitempty"` + AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"` + UsernameSource *userProfileSourceContext `json:"username_source,omitempty"` + DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"` + NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"` + ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"` + Identities service.UserIdentitySummarySet `json:"identities"` + AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"` + IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"` + EmailBound bool `json:"email_bound"` + LinuxDoBound bool `json:"linuxdo_bound"` + OIDCBound bool `json:"oidc_bound"` + WeChatBound bool `json:"wechat_bound"` +} + +type userProfileSourceContext struct { + Provider string `json:"provider,omitempty"` + Source string `json:"source,omitempty"` +} + // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { @@ -47,13 +80,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, dto.UserFromService(userData)) + response.Success(c, profileResp) } // ChangePassword handles changing user password @@ -101,6 +140,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { svcReq := service.UpdateProfileRequest{ Username: req.Username, + AvatarURL: req.AvatarURL, BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyThreshold: req.BalanceNotifyThreshold, } @@ -110,7 +150,149 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +type StartIdentityBindingRequest struct { + Provider string `json:"provider" binding:"required"` + RedirectTo string `json:"redirect_to"` +} + +type BindEmailIdentityRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code" binding:"required"` + Password string `json:"password" binding:"required"` +} + +type SendEmailBindingCodeRequest struct { + Email string `json:"email" binding:"required,email"` +} + +// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow. +// POST /api/v1/user/auth-identities/bind/start +func (h *UserHandler) StartIdentityBinding(c *gin.Context) { + if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req StartIdentityBindingRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{ + Provider: req.Provider, + RedirectTo: req.RedirectTo, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + +// BindEmailIdentity verifies and binds a local email identity for the current user. +// POST /api/v1/user/account-bindings/email +func (h *UserHandler) BindEmailIdentity(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req BindEmailIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + updatedUser, err := h.authService.BindEmailIdentity( + c.Request.Context(), + subject.UserID, + req.Email, + req.VerifyCode, + req.Password, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +// UnbindIdentity removes a third-party sign-in provider from the current user. +// DELETE /api/v1/user/account-bindings/:provider +func (h *UserHandler) UnbindIdentity(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + updatedUser, err := h.userService.UnbindUserAuthProvider( + c.Request.Context(), + subject.UserID, + c.Param("provider"), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +// SendEmailBindingCode sends a verification code for the current user's email binding flow. +// POST /api/v1/user/account-bindings/email/send-code +func (h *UserHandler) SendEmailBindingCode(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req SendEmailBindingCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Verification code sent successfully"}) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -176,7 +358,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -212,7 +400,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -248,5 +442,116 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) { + identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user) + if err != nil { + return userProfileResponse{}, err + } + return userProfileResponseFromService(user, identities), nil +} + +func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse { + base := dto.UserFromService(user) + if base == nil { + return userProfileResponse{} + } + bindings := userProfileBindingMap(identities) + profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities) + return userProfileResponse{ + User: *base, + AvatarURL: user.AvatarURL, + AvatarSource: avatarSource, + UsernameSource: usernameSource, + DisplayNameSource: usernameSource, + NicknameSource: usernameSource, + ProfileSources: profileSources, + Identities: identities, + AuthBindings: bindings, + IdentityBindings: bindings, + EmailBound: identities.Email.Bound, + LinuxDoBound: identities.LinuxDo.Bound, + OIDCBound: identities.OIDC.Bound, + WeChatBound: identities.WeChat.Bound, + } +} + +func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary { + return map[string]service.UserIdentitySummary{ + "email": identities.Email, + "linuxdo": identities.LinuxDo, + "oidc": identities.OIDC, + "wechat": identities.WeChat, + } +} + +func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) ( + map[string]*userProfileSourceContext, + *userProfileSourceContext, + *userProfileSourceContext, +) { + if user == nil { + return nil, nil, nil + } + + thirdParty := thirdPartyIdentityProviders(identities) + var avatarSource *userProfileSourceContext + if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 { + avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + usernameValue := strings.TrimSpace(user.Username) + var usernameSource *userProfileSourceContext + for _, summary := range thirdParty { + if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) { + usernameSource = buildUserProfileSourceContext(summary.Provider) + break + } + } + if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 { + usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + profileSources := map[string]*userProfileSourceContext{} + if avatarSource != nil { + profileSources["avatar"] = avatarSource + } + if usernameSource != nil { + profileSources["username"] = usernameSource + profileSources["display_name"] = usernameSource + profileSources["nickname"] = usernameSource + } + if len(profileSources) == 0 { + return nil, avatarSource, usernameSource + } + return profileSources, avatarSource, usernameSource +} + +func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { + out := make([]service.UserIdentitySummary, 0, 3) + for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { + if summary.Bound { + out = append(out, summary) + } + } + return out +} + +func buildUserProfileSourceContext(provider string) *userProfileSourceContext { + provider = strings.TrimSpace(provider) + if provider == "" { + return nil + } + return &userProfileSourceContext{ + Provider: provider, + Source: provider, + } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go new file mode 100644 index 0000000000000000000000000000000000000000..51d5a8142d3a94836e6b5be6d429da30110c2a9f --- /dev/null +++ b/backend/internal/handler/user_handler_test.go @@ -0,0 +1,593 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userHandlerRepoStub struct { + user *service.User + identities []service.UserAuthIdentityRecord + unbound []string +} + +func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } +func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error { + cloned := *user + s.user = &cloned + return nil +} +func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + if s.user == nil || s.user.AvatarURL == "" { + return nil, nil + } + return &service.UserAvatar{ + StorageProvider: s.user.AvatarSource, + URL: s.user.AvatarURL, + ContentType: s.user.AvatarMIME, + ByteSize: s.user.AvatarByteSize, + SHA256: s.user.AvatarSHA256, + }, nil +} +func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + s.user.AvatarURL = input.URL + s.user.AvatarSource = input.StorageProvider + s.user.AvatarMIME = input.ContentType + s.user.AvatarByteSize = input.ByteSize + s.user.AvatarSHA256 = input.SHA256 + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error { + s.user.AvatarURL = "" + s.user.AvatarSource = "" + s.user.AvatarMIME = "" + s.user.AvatarByteSize = 0 + s.user.AvatarSHA256 = "" + return nil +} +func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} +func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error { + if s.user != nil { + s.user.LastActiveAt = &activeAt + } + return nil +} +func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) { + out := make([]service.UserAuthIdentityRecord, len(s.identities)) + copy(out, s.identities) + return out, nil +} +func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error { + s.unbound = append(s.unbound, provider) + filtered := s.identities[:0] + for _, identity := range s.identities { + if identity.ProviderType == provider { + continue + } + filtered = append(filtered, identity) + } + s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...) + return nil +} + +func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "handler-avatar@example.com", + Username: "handler-avatar", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.UpdateProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + AvatarURL string `json:"avatar_url"` + Username string `json:"username"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) + require.Equal(t, "handler-avatar", resp.Data.Username) +} + +func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-123456", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + { + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "oidc-user-abc", + Metadata: map[string]any{ + "suggested_display_name": "OIDC Display", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Identities struct { + Email struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + } `json:"email"` + LinuxDo struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"linuxdo"` + OIDC struct { + Bound bool `json:"bound"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"oidc"` + WeChat struct { + Bound bool `json:"bound"` + CanBind bool `json:"can_bind"` + BindStartPath string `json:"bind_start_path"` + } `json:"wechat"` + } `json:"identities"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.Identities.Email.Bound) + require.Equal(t, 1, resp.Data.Identities.Email.BoundCount) + require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName) + require.True(t, resp.Data.Identities.LinuxDo.Bound) + require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount) + require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName) + require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey) + require.True(t, resp.Data.Identities.OIDC.Bound) + require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName) + require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey) + require.False(t, resp.Data.Identities.WeChat.Bound) + require.True(t, resp.Data.Identities.WeChat.CanBind) + require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start") +} + +func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "legacy-profile@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, false, resp.Data["oidc_bound"]) + require.Equal(t, false, resp.Data["wechat_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + require.Equal(t, "linuxdo", avatarSource["source"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + require.Equal(t, "linuxdo", linuxdoBinding["provider"]) + + identityBindings, ok := resp.Data["identity_bindings"].(map[string]any) + require.True(t, ok) + emailBinding, ok := identityBindings["email"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, emailBinding["bound"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) + require.Equal(t, "linuxdo", usernameSource["source"]) +} + +type userHandlerEmailCacheStub struct { + data *service.VerificationCodeData +} + +func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + return s.data, nil +} + +func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain, + Username: "legacy-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "provider", Value: "email"}} + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Email string `json:"email"` + EmailBound bool `json:"email_bound"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "new@example.com", resp.Data.Email) + require.True(t, resp.Data.EmailBound) +} + +func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []string{"linuxdo"}, repo.unbound) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, false, linuxdoBinding["bound"]) +} + +func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Email: "current@example.com", + Username: "bound-user", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, user.SetPassword("current-password")) + + repo := &userHandlerRepoStub{user: user} + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "PASSWORD_INCORRECT", resp.Reason) + require.Equal(t, "current password is incorrect", resp.Message) + require.Equal(t, "current@example.com", repo.user.Email) +} + +func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.StartIdentityBinding(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Provider string `json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "wechat", resp.Data.Provider) + require.Equal(t, "GET", resp.Data.Method) + require.True(t, resp.Data.UseBrowserRedirect) + require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start") + require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user") + require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile") +} diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index ec244cd676706ad886b446bb78298df4f6b72fb1..41fd2c50cf52a59f8e04ffad6bb33d2e1936b4a4 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -45,11 +45,31 @@ type DefaultLoadBalancer struct { counter atomic.Uint64 } +type contextKey string + +const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id" + // NewDefaultLoadBalancer creates a new load balancer. func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer { return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey} } +func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context { + appID = strings.TrimSpace(appID) + if appID == "" { + return ctx + } + return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID) +} + +func wxpayJSAPIAppIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string) + return strings.TrimSpace(appID) +} + // instanceCandidate pairs an instance with its pre-fetched daily usage. type instanceCandidate struct { inst *dbent.PaymentProviderInstance @@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( } var matched []*dbent.PaymentProviderInstance + expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx) for _, inst := range instances { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". @@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( matched = append(matched, inst) } } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { + if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay { + config, cfgErr := lb.decryptConfig(inst.Config) + if cfgErr != nil { + slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr) + continue + } + if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID { + continue + } + } matched = append(matched, inst) } } @@ -231,6 +262,11 @@ func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType P if cl, ok := limits[lookupKey]; ok { return cl } + if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" { + if cl, ok := limits[aliasKey]; ok { + return cl + } + } return ChannelLimits{} } @@ -344,14 +380,45 @@ func InstanceSupportsType(supportedTypes string, target PaymentType) bool { if supportedTypes == "" { return true } + normalizedTarget := normalizeVisibleMethodSupportType(target) for _, t := range strings.Split(supportedTypes, ",") { - if strings.TrimSpace(t) == target { + supported := strings.TrimSpace(t) + if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget { return true } } return false } +func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType { + switch strings.TrimSpace(paymentType) { + case TypeAlipay, TypeAlipayDirect: + return TypeAlipay + case TypeWxpay, TypeWxpayDirect: + return TypeWxpay + default: + return strings.TrimSpace(paymentType) + } +} + +func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType { + switch normalizeVisibleMethodSupportType(paymentType) { + case TypeAlipay: + return TypeAlipayDirect + case TypeWxpay: + return TypeWxpayDirect + default: + return "" + } +} + +func resolveWxpayJSAPIAppID(config map[string]string) string { + if appID := strings.TrimSpace(config["mpAppId"]); appID != "" { + return appID + } + return strings.TrimSpace(config["appId"]) +} + // GetInstanceConfig decrypts and returns the configuration for a provider instance by ID. func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID) diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 2bf4f6ac4116ad4a9dbe59f85a05926ffe507b63..ed08a7dd49092bcf0d3938077d51e7bbb8c50f9d 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -68,10 +68,16 @@ func TestInstanceSupportsType(t *testing.T) { expected: true, }, { - name: "partial match should not succeed", + name: "legacy alipay direct supports canonical visible method", supportedTypes: "alipay_direct", target: "alipay", - expected: false, + expected: true, + }, + { + name: "legacy wxpay direct supports canonical visible method", + supportedTypes: "wxpay_direct", + target: "wxpay", + expected: true, }, { name: "empty supported types means all supported", @@ -92,6 +98,22 @@ func TestInstanceSupportsType(t *testing.T) { } } +func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) { + t.Parallel() + + inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66})) + got := getInstanceChannelLimits(inst, TypeAlipay) + if got.SingleMax != 66 { + t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got) + } + + wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8})) + wxGot := getInstanceChannelLimits(wxInst, TypeWxpay) + if wxGot.SingleMin != 8 { + t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot) + } +} + // --------------------------------------------------------------------------- // Helper to build test PaymentProviderInstance values // --------------------------------------------------------------------------- diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index fe8ea89c0a1df274c857af25aa4ed223026759f4..4a26029586bc1ee6cbdd749fc99a4619a0a1b876 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -26,6 +26,15 @@ const ( alipayRefundSuffix = "-refund" ) +var ( + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + return client.TradeWapPay(param) + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + return client.TradePagePay(param) + } +) + // Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK. type Alipay struct { instanceID string @@ -79,6 +88,17 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } +func (a *Alipay) MerchantIdentityMetadata() map[string]string { + if a == nil { + return nil + } + appID := strings.TrimSpace(a.config["appId"]) + if appID == "" { + return nil + } + return map[string]string{"app_id": appID} +} + // CreatePayment creates an Alipay payment using redirect-only flow: // - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to. // - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a @@ -115,7 +135,7 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment param.NotifyURL = notifyURL param.ReturnURL = returnURL - payURL, err := client.TradeWapPay(param) + payURL, err := alipayTradeWapPay(client, param) if err != nil { return nil, fmt.Errorf("alipay TradeWapPay: %w", err) } @@ -134,7 +154,7 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay param.NotifyURL = notifyURL param.ReturnURL = returnURL - payURL, err := client.TradePagePay(param) + payURL, err := alipayTradePagePay(client, param) if err != nil { return nil, fmt.Errorf("alipay TradePagePay: %w", err) } @@ -176,10 +196,11 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query } return &payment.QueryOrderResponse{ - TradeNo: result.TradeNo, - Status: status, - Amount: amount, - PaidAt: result.SendPayDate, + TradeNo: result.TradeNo, + Status: status, + Amount: amount, + PaidAt: result.SendPayDate, + Metadata: a.MerchantIdentityMetadata(), }, nil } @@ -210,12 +231,21 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err) } + metadata := a.MerchantIdentityMetadata() + if appID := strings.TrimSpace(notification.AppId); appID != "" { + if metadata == nil { + metadata = map[string]string{} + } + metadata["app_id"] = appID + } + return &payment.PaymentNotification{ - TradeNo: notification.TradeNo, - OrderID: notification.OutTradeNo, - Amount: amount, - Status: status, - RawData: rawBody, + TradeNo: notification.TradeNo, + OrderID: notification.OutTradeNo, + Amount: amount, + Status: status, + RawData: rawBody, + Metadata: metadata, }, nil } @@ -278,6 +308,7 @@ func isTradeNotExist(err error) bool { // Ensure interface compliance. var ( - _ payment.Provider = (*Alipay)(nil) - _ payment.CancelableProvider = (*Alipay)(nil) + _ payment.Provider = (*Alipay)(nil) + _ payment.CancelableProvider = (*Alipay)(nil) + _ payment.MerchantIdentityProvider = (*Alipay)(nil) ) diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 7b0ce0d8c86223743bfe8a92f0094033914d6cb6..8b3ff8ce8796f7ae99896675ccb558ab55e61eea 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -4,8 +4,12 @@ package provider import ( "errors" + "net/url" "strings" "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/smartwalle/alipay/v3" ) func TestIsTradeNotExist(t *testing.T) { @@ -130,3 +134,96 @@ func TestNewAlipay(t *testing.T) { }) } } + +func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { + origPagePay := alipayTradePagePay + origWapPay := alipayTradeWapPay + t.Cleanup(func() { + alipayTradePagePay = origPagePay + alipayTradeWapPay = origWapPay + }) + + pagePayCalls := 0 + wapPayCalls := 0 + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ + if param.OutTradeNo != "sub2_100" { + t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100") + } + if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" { + t.Fatalf("notify_url = %q", param.NotifyURL) + } + return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") + } + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + wapPayCalls++ + return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay") + } + + provider := &Alipay{} + resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_100", + Amount: "88.00", + Subject: "Balance recharge", + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pagePayCalls != 1 { + t.Fatalf("page pay calls = %d, want 1", pagePayCalls) + } + if wapPayCalls != 0 { + t.Fatalf("wap pay calls = %d, want 0", wapPayCalls) + } + if resp.PayURL == "" { + t.Fatal("expected pay_url for desktop page pay") + } +} + +func TestCreateTradeUsesWapPayForMobile(t *testing.T) { + origWapPay := alipayTradeWapPay + t.Cleanup(func() { + alipayTradeWapPay = origWapPay + }) + + wapPayCalls := 0 + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + wapPayCalls++ + if param.ReturnURL != "https://merchant.example.com/payment/result" { + t.Fatalf("return_url = %q", param.ReturnURL) + } + return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay") + } + + provider := &Alipay{} + resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_101", + Amount: "18.00", + Subject: "Balance recharge", + IsMobile: true, + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wapPayCalls != 1 { + t.Fatalf("wap pay calls = %d, want 1", wapPayCalls) + } + if resp.PayURL == "" { + t.Fatal("expected pay_url for mobile wap pay") + } +} + +func TestAlipayMerchantIdentityMetadata(t *testing.T) { + t.Parallel() + + provider := &Alipay{ + config: map[string]string{ + "appId": "2021001234567890", + }, + } + + metadata := provider.MerchantIdentityMetadata() + if metadata["app_id"] != "2021001234567890" { + t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890") + } +} diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index e33a567d0398b240a26144f3f1dd44836748d23d..37bd38b27fc39bfeb055c739b49b399fb32ed9c0 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay} } +func (e *EasyPay) MerchantIdentityMetadata() map[string]string { + if e == nil { + return nil + } + pid := strings.TrimSpace(e.config["pid"]) + if pid == "" { + return nil + } + return map[string]string{"pid": pid} +} + func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { // Payment mode determined by instance config, not payment type. // "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php). @@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer status = payment.ProviderStatusPaid } amount, _ := strconv.ParseFloat(resp.Money, 64) - return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil + return &payment.QueryOrderResponse{ + TradeNo: tradeNo, + Status: status, + Amount: amount, + Metadata: e.MerchantIdentityMetadata(), + }, nil } func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) { @@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st status = payment.ProviderStatusSuccess } amount, _ := strconv.ParseFloat(params["money"], 64) + + metadata := e.MerchantIdentityMetadata() + if pid := strings.TrimSpace(params["pid"]); pid != "" { + if metadata == nil { + metadata = map[string]string{} + } + metadata["pid"] = pid + } return &payment.PaymentNotification{ TradeNo: params["trade_no"], OrderID: params["out_trade_no"], - Amount: amount, Status: status, RawData: rawBody, + Amount: amount, Status: status, RawData: rawBody, Metadata: metadata, }, nil } diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go index 146a6fa1afd7aea5649cd9d110edcce86dbbdbb0..8328d294e88bae3bdcfb08ef9c267705953bdba0 100644 --- a/backend/internal/payment/provider/easypay_sign_test.go +++ b/backend/internal/payment/provider/easypay_sign_test.go @@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) { t.Fatal("easyPayVerifySign should return false for an incorrect sign value") } } + +func TestEasyPayMerchantIdentityMetadata(t *testing.T) { + t.Parallel() + + provider := &EasyPay{ + config: map[string]string{ + "pid": "1001", + }, + } + + metadata := provider.MerchantIdentityMetadata() + if metadata["pid"] != "1001" { + t.Fatalf("pid = %q, want %q", metadata["pid"], "1001") + } +} diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 4df764526e41333812c6977d0b6105c50fffb846..4b3345132909aee15ba543de0a1103013ea082ea 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -5,8 +5,8 @@ import ( "context" "fmt" "io" - "log/slog" "net/http" + "net/url" "strconv" "strings" "sync" @@ -20,6 +20,7 @@ import ( "github.com/wechatpay-apiv3/wechatpay-go/core/option" "github.com/wechatpay-apiv3/wechatpay-go/services/payments" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" "github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic" "github.com/wechatpay-apiv3/wechatpay-go/utils" @@ -27,8 +28,23 @@ import ( // WeChat Pay constants. const ( - wxpayCurrency = "CNY" - wxpayH5Type = "Wap" + wxpayCurrency = "CNY" + wxpayH5Type = "Wap" + wxpayResultPath = "/payment/result" +) + +const ( + wxpayMetadataAppID = "appid" + wxpayMetadataMerchantID = "mchid" + wxpayMetadataCurrency = "currency" + wxpayMetadataTradeState = "trade_state" +) + +// WeChat Pay create-payment modes. +const ( + wxpayModeNative = "native" + wxpayModeH5 = "h5" + wxpayModeJSAPI = "jsapi" ) // WeChat Pay trade states. @@ -49,6 +65,18 @@ const ( wxpayErrNoAuth = "NO_AUTH" ) +var ( + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + return svc.Prepay(ctx, req) + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + return svc.Prepay(ctx, req) + } + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + return svc.PrepayWithRequestPayment(ctx, req) + } +) + type Wxpay struct { instanceID string config map[string]string @@ -96,6 +124,16 @@ func (w *Wxpay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeWxpay} } +// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a +// given provider config. A dedicated MP AppID takes precedence over the base +// merchant AppID. +func ResolveWxpayJSAPIAppID(config map[string]string) string { + if appID := strings.TrimSpace(config["mpAppId"]); appID != "" { + return appID + } + return strings.TrimSpace(config["appId"]) +} + func formatPEM(key, keyType string) string { key = strings.TrimSpace(key) if strings.HasPrefix(key, "-----BEGIN") { @@ -153,30 +191,68 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ if err != nil { return nil, fmt.Errorf("wxpay create payment: %w", err) } - if req.IsMobile && req.ClientIP != "" { - resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true) + + mode, err := resolveWxpayCreateMode(req) + if err != nil { + return nil, err + } + switch mode { + case wxpayModeJSAPI: + return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen) + case wxpayModeH5: + resp, err := w.prepayH5(ctx, client, req, notifyURL, totalFen) if err == nil { return resp, nil } - if !strings.Contains(err.Error(), wxpayErrNoAuth) { - return nil, err + if strings.Contains(err.Error(), wxpayErrNoAuth) { + return nil, fmt.Errorf("wxpay h5 payments are not authorized for this merchant: %w", err) } - slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID) + return nil, err + case wxpayModeNative: + return w.prepayNative(ctx, client, req, notifyURL, totalFen) + default: + return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode) } - return w.createOrder(ctx, client, req, notifyURL, totalFen, false) } -func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) { - if useH5 { - return w.prepayH5(ctx, c, req, notifyURL, totalFen) +func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { + svc := jsapi.JsapiApiService{Client: c} + cur := wxpayCurrency + appID := ResolveWxpayJSAPIAppID(w.config) + prepayReq := jsapi.PrepayRequest{ + Appid: core.String(appID), + Mchid: core.String(w.config["mchId"]), + Description: core.String(req.Subject), + OutTradeNo: core.String(req.OrderID), + NotifyUrl: core.String(notifyURL), + Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur}, + Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))}, } - return w.prepayNative(ctx, c, req, notifyURL, totalFen) + if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" { + prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)} + } + resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq) + if err != nil { + return nil, fmt.Errorf("wxpay jsapi prepay: %w", err) + } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + ResultType: payment.CreatePaymentResultJSAPIReady, + JSAPI: &payment.WechatJSAPIPayload{ + AppID: wxSV(resp.Appid), + TimeStamp: wxSV(resp.TimeStamp), + NonceStr: wxSV(resp.NonceStr), + Package: wxSV(resp.Package), + SignType: wxSV(resp.SignType), + PaySign: wxSV(resp.PaySign), + }, + }, nil } func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { svc := native.NativeApiService{Client: c} cur := wxpayCurrency - resp, _, err := svc.Prepay(ctx, native.PrepayRequest{ + resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]), Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID), NotifyUrl: core.String(notifyURL), @@ -195,13 +271,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { svc := h5.H5ApiService{Client: c} cur := wxpayCurrency - tp := wxpayH5Type - resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{ + resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]), Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID), NotifyUrl: core.String(notifyURL), Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur}, - SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}}, + SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)}, }) if err != nil { return nil, fmt.Errorf("wxpay h5 prepay: %w", err) @@ -210,9 +285,77 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create if resp.H5Url != nil { h5URL = *resp.H5Url } + h5URL, err = appendWxpayRedirectURL(h5URL, req) + if err != nil { + return nil, err + } return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil } +func buildWxpayH5Info(config map[string]string) *h5.H5Info { + tp := wxpayH5Type + info := &h5.H5Info{Type: &tp} + if appName := strings.TrimSpace(config["h5AppName"]); appName != "" { + info.AppName = core.String(appName) + } + if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" { + info.AppUrl = core.String(appURL) + } + return info +} + +func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) { + if strings.TrimSpace(req.OpenID) != "" { + return wxpayModeJSAPI, nil + } + if req.IsMobile { + if strings.TrimSpace(req.ClientIP) == "" { + return "", fmt.Errorf("wxpay H5 payment requires client IP") + } + return wxpayModeH5, nil + } + return wxpayModeNative, nil +} + +func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) { + h5URL = strings.TrimSpace(h5URL) + returnURL := strings.TrimSpace(req.ReturnURL) + if h5URL == "" || returnURL == "" { + return h5URL, nil + } + + redirectURL, err := buildWxpayResultURL(returnURL, req) + if err != nil { + return "", err + } + + sep := "&" + if !strings.Contains(h5URL, "?") { + sep = "?" + } + return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil +} + +func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) { + u, err := url.Parse(returnURL) + if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") { + return "", fmt.Errorf("return URL must be an absolute http(s) URL") + } + + values := u.Query() + values.Set("out_trade_no", strings.TrimSpace(req.OrderID)) + if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" { + values.Set("payment_type", paymentType) + } + if strings.TrimSpace(u.Path) == "" { + u.Path = wxpayResultPath + } + u.RawPath = "" + u.RawQuery = values.Encode() + u.Fragment = "" + return u.String(), nil +} + func wxSV(s *string) string { if s == nil { return "" @@ -233,6 +376,32 @@ func mapWxState(s string) string { } } +func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string { + if tx == nil { + return nil + } + + metadata := map[string]string{} + if appID := wxSV(tx.Appid); appID != "" { + metadata[wxpayMetadataAppID] = appID + } + if merchantID := wxSV(tx.Mchid); merchantID != "" { + metadata[wxpayMetadataMerchantID] = merchantID + } + if tradeState := wxSV(tx.TradeState); tradeState != "" { + metadata[wxpayMetadataTradeState] = tradeState + } + if tx.Amount != nil { + if currency := wxSV(tx.Amount.Currency); currency != "" { + metadata[wxpayMetadataCurrency] = currency + } + } + if len(metadata) == 0 { + return nil + } + return metadata +} + func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { c, err := w.ensureClient() if err != nil { @@ -257,7 +426,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO if tx.SuccessTime != nil { pa = *tx.SuccessTime } - return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil + return &payment.QueryOrderResponse{ + TradeNo: id, + Status: mapWxState(wxSV(tx.TradeState)), + Amount: amt, + PaidAt: pa, + Metadata: buildWxpayTransactionMetadata(tx), + }, nil } func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { @@ -289,7 +464,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers } return &payment.PaymentNotification{ TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo), - Amount: amt, Status: st, RawData: rawBody, + Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx), }, nil } diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index 707fec18c87fe7b365db3713d698c20322ad8399..ebbd9d344d688e7de24a20cd61dfecd949f33cbc 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -3,14 +3,21 @@ package provider import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" + "net/url" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/wechatpay-apiv3/wechatpay-go/core" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" ) // generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings. @@ -120,6 +127,33 @@ func TestWxSV(t *testing.T) { } } +func TestBuildWxpayTransactionMetadata(t *testing.T) { + t.Parallel() + + tx := &payments.Transaction{ + Appid: strPtr("wx-app-id"), + Mchid: strPtr("mch-id"), + TradeState: strPtr(wxpayTradeStateSuccess), + Amount: &payments.TransactionAmount{ + Currency: strPtr(wxpayCurrency), + }, + } + + metadata := buildWxpayTransactionMetadata(tx) + if metadata[wxpayMetadataAppID] != "wx-app-id" { + t.Fatalf("appid = %q", metadata[wxpayMetadataAppID]) + } + if metadata[wxpayMetadataMerchantID] != "mch-id" { + t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID]) + } + if metadata[wxpayMetadataCurrency] != wxpayCurrency { + t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency]) + } + if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess { + t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState]) + } +} + func strPtr(s string) *string { return &s } @@ -300,3 +334,310 @@ func TestNewWxpay(t *testing.T) { }) } } + +func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) { + t.Parallel() + + resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{ + OrderID: "sub2_42", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("buildWxpayResultURL returned error: %v", err) + } + + parsed, err := url.Parse(resultURL) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + query := parsed.Query() + if parsed.Path != wxpayResultPath { + t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath) + } + if query.Get("resume_token") != "resume-42" { + t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42") + } + if query.Get("order_id") != "42" { + t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42") + } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42") + } +} + +func TestResolveWxpayJSAPIAppID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config map[string]string + want string + }{ + { + name: "prefers dedicated mp app id", + config: map[string]string{ + "mpAppId": "wx-mp-app", + "appId": "wx-merchant-app", + }, + want: "wx-mp-app", + }, + { + name: "falls back to merchant app id", + config: map[string]string{ + "appId": "wx-merchant-app", + }, + want: "wx-merchant-app", + }, + { + name: "missing app ids returns empty", + config: map[string]string{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want { + t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveWxpayCreateMode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req payment.CreatePaymentRequest + wantMode string + wantErr string + }{ + { + name: "desktop uses native", + req: payment.CreatePaymentRequest{}, + wantMode: wxpayModeNative, + }, + { + name: "mobile uses h5 when client ip is present", + req: payment.CreatePaymentRequest{ + IsMobile: true, + ClientIP: "203.0.113.10", + }, + wantMode: wxpayModeH5, + }, + { + name: "mobile without client ip returns clear error", + req: payment.CreatePaymentRequest{ + IsMobile: true, + }, + wantErr: "requires client IP", + }, + { + name: "openid uses jsapi mode", + req: payment.CreatePaymentRequest{ + OpenID: "openid-123", + }, + wantMode: wxpayModeJSAPI, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := resolveWxpayCreateMode(tt.req) + if tt.wantErr != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantMode { + t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode) + } + }) + } +} + +func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + if got := wxSV(req.Payer.Openid); got != "openid-123" { + t.Fatalf("openid = %q, want %q", got, "openid-123") + } + if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" { + t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10") + } + return &jsapi.PrepayWithRequestPaymentResponse{ + Appid: core.String("wx123"), + TimeStamp: core.String("1712345678"), + NonceStr: core.String("nonce-123"), + Package: core.String("prepay_id=wx_prepay_123"), + SignType: core.String("RSA"), + PaySign: core.String("signed-payload"), + }, nil, nil + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + h5Calls++ + return &h5.PrepayResponse{}, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + }, + coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: "sub2_88", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + NotifyURL: "https://merchant.example/payment/notify", + OpenID: "openid-123", + ClientIP: "203.0.113.10", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if jsapiCalls != 1 { + t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls) + } + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) + } + if h5Calls != 0 { + t.Fatalf("h5 prepay calls = %d, want 0", h5Calls) + } + if resp.ResultType != payment.CreatePaymentResultJSAPIReady { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady) + } + if resp.JSAPI == nil { + t.Fatal("expected jsapi payload, got nil") + } + if resp.JSAPI.AppID != "wx123" { + t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123") + } + if resp.JSAPI.TimeStamp != "1712345678" { + t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678") + } + if resp.JSAPI.NonceStr != "nonce-123" { + t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123") + } + if resp.JSAPI.Package != "prepay_id=wx_prepay_123" { + t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123") + } + if resp.JSAPI.SignType != "RSA" { + t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA") + } + if resp.JSAPI.PaySign != "signed-payload" { + t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload") + } +} + +func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + h5Calls++ + if req.SceneInfo == nil { + t.Fatal("expected scene_info, got nil") + } + if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" { + t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10") + } + if req.SceneInfo.H5Info == nil { + t.Fatal("expected scene_info.h5_info, got nil") + } + if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type { + t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type) + } + if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" { + t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API") + } + if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" { + t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com") + } + return &h5.PrepayResponse{ + H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"), + }, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + "h5AppName": "Sub2API", + "h5AppUrl": "https://app.example.com", + }, + coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: "sub2_99", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + Subject: "Balance Recharge", + NotifyURL: "https://merchant.example/payment/notify", + ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99", + ClientIP: "203.0.113.10", + IsMobile: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if jsapiCalls != 0 { + t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) + } + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) + } + if h5Calls != 1 { + t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) + } + if !strings.Contains(resp.PayURL, "redirect_url=") { + t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL) + } +} diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go index 5d613a4a9d746480afbd4bbd583fc778a27ba617..e7ac6727b95c2cb9d56fb61fe400380751f68c52 100644 --- a/backend/internal/payment/types.go +++ b/backend/internal/payment/types.go @@ -101,34 +101,69 @@ type CreatePaymentRequest struct { Subject string // Product description NotifyURL string // Webhook callback URL ReturnURL string // Browser redirect URL after payment + OpenID string // WeChat JSAPI payer OpenID when available ClientIP string // Payer's IP address IsMobile bool // Whether the request comes from a mobile device InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe) } +// CreatePaymentResultType describes the shape of the create-payment result. +type CreatePaymentResultType = string + +const ( + CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created" + CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required" + CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready" +) + +// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment. +type WechatOAuthInfo struct { + AuthorizeURL string `json:"authorize_url,omitempty"` + AppID string `json:"appid,omitempty"` + OpenID string `json:"openid,omitempty"` + Scope string `json:"scope,omitempty"` + State string `json:"state,omitempty"` + RedirectURL string `json:"redirect_url,omitempty"` +} + +// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment. +type WechatJSAPIPayload struct { + AppID string `json:"appId,omitempty"` + TimeStamp string `json:"timeStamp,omitempty"` + NonceStr string `json:"nonceStr,omitempty"` + Package string `json:"package,omitempty"` + SignType string `json:"signType,omitempty"` + PaySign string `json:"paySign,omitempty"` +} + // CreatePaymentResponse is returned after successfully initiating a payment. type CreatePaymentResponse struct { - TradeNo string // Third-party transaction ID - PayURL string // H5 payment URL (alipay/wxpay) - QRCode string // QR code content for scanning - ClientSecret string // Stripe PaymentIntent client secret + TradeNo string // Third-party transaction ID + PayURL string // H5 payment URL (alipay/wxpay) + QRCode string // QR code content for scanning + ClientSecret string // Stripe PaymentIntent client secret + ResultType CreatePaymentResultType // Typed result contract for frontend flows + OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required + JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready } // QueryOrderResponse describes the payment status from the upstream provider. type QueryOrderResponse struct { - TradeNo string - Status string // "pending", "paid", "failed", "refunded" - Amount float64 // Amount in CNY - PaidAt string // RFC3339 timestamp or empty + TradeNo string + Status string // "pending", "paid", "failed", "refunded" + Amount float64 // Amount in CNY + PaidAt string // RFC3339 timestamp or empty + Metadata map[string]string } // PaymentNotification is the parsed result of a webhook/notify callback. type PaymentNotification struct { - TradeNo string - OrderID string - Amount float64 - Status string // "success" or "failed" - RawData string // Raw notification body for audit + TradeNo string + OrderID string + Amount float64 + Status string // "success" or "failed" + RawData string // Raw notification body for audit + Metadata map[string]string } // RefundRequest contains the parameters for requesting a refund. @@ -179,3 +214,9 @@ type CancelableProvider interface { // CancelPayment cancels/expires a pending payment on the upstream platform. CancelPayment(ctx context.Context, tradeNo string) error } + +// MerchantIdentityProvider exposes the current non-sensitive merchant identity +// derived from provider configuration for snapshot consistency checks. +type MerchantIdentityProvider interface { + MerchantIdentityMetadata() map[string]string +} diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go index 2dc346b15544ef3f8a886bbec91acf145a3d8894..5268ec45ff610cdcf37fce9e42f54b6f223a150f 100644 --- a/backend/internal/repository/announcement_read_repo.go +++ b/backend/internal/repository/announcement_read_repo.go @@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error { client := clientFromContext(ctx, r.client) - return client.AnnouncementRead.Create(). + err := client.AnnouncementRead.Create(). SetAnnouncementID(announcementID). SetUserID(userID). SetReadAt(readAt). OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID). DoNothing(). Exec(ctx) + if isSQLNoRowsError(err) { + return nil + } + return err } func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 38ea9bde3d73e704e014d793dd1bb366a277c5cc..36d80309f7e46be0aa216d8f057aed1c93ae74a0 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyExtraEmails, user.FieldTotalRecharged, + user.FieldSignupSource, + user.FieldLastLoginAt, + user.FieldLastActiveAt, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User { Balance: u.Balance, Concurrency: u.Concurrency, Status: u.Status, + SignupSource: u.SignupSource, + LastLoginAt: u.LastLoginAt, + LastActiveAt: u.LastActiveAt, TotpSecretEncrypted: u.TotpSecretEncrypted, TotpEnabled: u.TotpEnabled, TotpEnabledAt: u.TotpEnabledAt, diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..56b375125a9c3e3601ec4495102e284c2818438d --- /dev/null +++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go @@ -0,0 +1,73 @@ +//go:build integration + +package repository + +import ( + "context" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql") + migration108SQL, err := os.ReadFile(migration108Path) + require.NoError(t, err) + + migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql") + migration109SQL, err := os.ReadFile(migration109Path) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE; +DROP TABLE IF EXISTS auth_identity_channels CASCADE; +DROP TABLE IF EXISTS identity_adoption_decisions CASCADE; +DROP TABLE IF EXISTS pending_auth_sessions CASCADE; +DROP TABLE IF EXISTS auth_identities CASCADE; + +ALTER TABLE users + DROP COLUMN IF EXISTS signup_source, + DROP COLUMN IF EXISTS last_login_at, + DROP COLUMN IF EXISTS last_active_at; +`) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration108SQL)) + require.NoError(t, err) + + var userID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&userID)) + + _, err = tx.ExecContext(ctx, string(migration109SQL)) + require.NoError(t, err) + + var reportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery' + AND report_key = $1 +`, strconv.FormatInt(userID, 10)).Scan(&reportCount)) + require.Equal(t, 1, reportCount) + + var reportTypeLimit int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT character_maximum_length +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = 'auth_identity_migration_reports' + AND column_name = 'report_type' +`).Scan(&reportTypeLimit)) + require.GreaterOrEqual(t, reportTypeLimit, 45) + + require.NotZero(t, userID) +} diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e59c257cd3ef01c5da6b3c769b28bee7518f8a4e --- /dev/null +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -0,0 +1,648 @@ +//go:build integration + +package repository + +import ( + "context" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + + TRUNCATE TABLE + auth_identity_channels, + identity_adoption_decisions, + auth_identities, + auth_identity_migration_reports, + user_external_identities, + users + RESTART IDENTITY CASCADE; +`) + require.NoError(t, err) + + var linuxDoUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoUserID)) + + var wechatUnionUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatUnionUserID)) + + var wechatOpenIDOnlyUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatOpenIDOnlyUserID)) + + var syntheticAuthIdentityID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb) +RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID)) + + var linuxDoLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}') +RETURNING id +`, linuxDoUserID).Scan(&linuxDoLegacyID)) + + var wechatUnionLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}') +RETURNING id +`, wechatUnionUserID).Scan(&wechatUnionLegacyID)) + + var wechatOpenIDLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}') +RETURNING id +`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxDoCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-user-1' +`, linuxDoUserID).Scan(&linuxDoCount)) + require.Equal(t, 1, linuxDoCount) + + var wechatSubject string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT provider_subject +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-1' +`, wechatUnionUserID).Scan(&wechatSubject)) + require.Equal(t, "union-1", wechatSubject) + + var wechatChannelCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_channels channel +JOIN auth_identities ai ON ai.id = channel.identity_id +WHERE ai.user_id = $1 + AND channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = 'oa' + AND channel.channel_app_id = 'wx-app-1' + AND channel.channel_subject = 'openid-union-1' +`, wechatUnionUserID).Scan(&wechatChannelCount)) + require.Equal(t, 1, wechatChannelCount) + + var legacyOpenIDOnlyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount)) + require.Equal(t, 1, legacyOpenIDOnlyReportCount) + + var syntheticReviewCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount)) + require.Equal(t, 1, syntheticReviewCount) + + var unionLegacyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount)) + require.Zero(t, unionLegacyReportCount) + require.NotZero(t, linuxDoLegacyID) +} + +func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + var beforeCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +`).Scan(&beforeCount)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var afterCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports + `).Scan(&afterCount)) + require.Equal(t, beforeCount, afterCount) +} + +func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +TRUNCATE TABLE + auth_identity_channels, + identity_adoption_decisions, + auth_identities, + auth_identity_migration_reports, + user_external_identities, + users +RESTART IDENTITY CASCADE; +`) + require.NoError(t, err) + + var linuxDoMalformedUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoMalformedUserID)) + + var linuxDoArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoArrayUserID)) + + var wechatUnionArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatUnionArrayUserID)) + + var wechatOpenIDArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatOpenIDArrayUserID)) + + var linuxDoMalformedLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid') +RETURNING id +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID)) + + var linuxDoArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]') +RETURNING id +`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID)) + + var wechatUnionArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]') +RETURNING id +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID)) + + var wechatOpenIDArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]') +RETURNING id +`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var linuxDoMalformedMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-malformed' +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType)) + require.Equal(t, "object", linuxDoMalformedMetadataType) + + var linuxDoArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-array' +`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType)) + require.Equal(t, "object", linuxDoArrayMetadataType) + + var wechatUnionArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-array' +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType)) + require.Equal(t, "object", wechatUnionArrayMetadataType) + + var invalidJSONReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType)) + require.Equal(t, "object", invalidJSONReportDetailsType) + + var openIDOnlyReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType)) + require.Equal(t, "object", openIDOnlyReportDetailsType) + + var preservedArrayMetadataCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE id IN ( + SELECT id + FROM auth_identities + WHERE (user_id = $1 AND provider_subject = 'linuxdo-array') + OR (user_id = $2 AND provider_subject = 'union-array') +) + AND metadata ? '_legacy_metadata_raw_json' +`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount)) + require.Equal(t, 2, preservedArrayMetadataCount) + + require.NotZero(t, linuxDoArrayLegacyID) + require.NotZero(t, wechatUnionArrayLegacyID) +} + +func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + + TRUNCATE TABLE + auth_identity_channels, + identity_adoption_decisions, + auth_identities, + auth_identity_migration_reports, + user_external_identities, + users + RESTART IDENTITY CASCADE; +`) + require.NoError(t, err) + + userIDs := make([]int64, 0, 8) + for _, email := range []string{ + "linuxdo-conflict-legacy@example.com", + "linuxdo-conflict-owner@example.com", + "wechat-conflict-legacy@example.com", + "wechat-conflict-owner@example.com", + "wechat-channel-legacy@example.com", + "wechat-channel-owner@example.com", + "linuxdo-invalid-json@example.com", + "wechat-openid-invalid-json@example.com", + } { + var userID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ($1, 'hash', 'user', 'active', 0, 1) +RETURNING id`, email).Scan(&userID)) + userIDs = append(userIDs, userID) + } + + linuxdoConflictLegacyUserID := userIDs[0] + linuxdoConflictOwnerUserID := userIDs[1] + wechatConflictLegacyUserID := userIDs[2] + wechatConflictOwnerUserID := userIDs[3] + wechatChannelLegacyUserID := userIDs[4] + wechatChannelOwnerUserID := userIDs[5] + linuxdoInvalidJSONUserID := userIDs[6] + wechatInvalidOpenIDUserID := userIDs[7] + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb) +RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb) +RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64))) + + var wechatChannelOwnerIdentityID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb) +RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID)) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb) +RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64))) + + var linuxdoConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}') +RETURNING id +`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID)) + + var wechatConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}') +RETURNING id +`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID)) + + var wechatChannelConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}') +RETURNING id +`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID)) + + var linuxdoInvalidJSONLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid') +RETURNING id +`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID)) + + var wechatInvalidOpenIDLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid') +RETURNING id +`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxdoConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount)) + require.Equal(t, 1, linuxdoConflictReportCount) + + var wechatConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount)) + require.Equal(t, 1, wechatConflictReportCount) + + var channelConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_channel_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount)) + require.Equal(t, 1, channelConflictReportCount) + + var invalidJSONReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key IN ($1, $2) +`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount)) + require.Equal(t, 2, invalidJSONReportCount) + + var linuxdoInvalidIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-invalid-json' +`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount)) + require.Equal(t, 1, linuxdoInvalidIdentityCount) + + var wechatOpenIDOnlyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount)) + require.Equal(t, 1, wechatOpenIDOnlyReportCount) +} + +func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + var beforeCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +`).Scan(&beforeCount)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var afterCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +`).Scan(&afterCount)) + require.Equal(t, beforeCount, afterCount) +} diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 9cf3b3920fb3393844d5d0fe6798df5c6e59f402..5a2e66778a0e723981f9b94bd293192fbe477ccf 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -73,6 +73,12 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {}, }, }, + "109_auth_identity_compat_backfill.sql": { + fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + acceptedDBChecksum: map[string]struct{}{ + "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {}, + }, + }, } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 6c3ad725fa541a4e79fb4360ee8f01baceb30e9e..6030991b6c172f0a02e1acb91496977392546820 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -51,4 +51,13 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { ) require.False(t, ok) }) + + t.Run("109历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + ) + require.True(t, ok) + }) } diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go new file mode 100644 index 0000000000000000000000000000000000000000..2d81239442ba9ea0715eb3419f3d98a87aa40fcd --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -0,0 +1,642 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "time" + "unsafe" + + entsql "entgo.io/ent/dialect/sql" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +var ( + ErrAuthIdentityOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_OWNERSHIP_CONFLICT", + "auth identity already belongs to another user", + ) + ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", + "auth identity channel already belongs to another user", + ) + ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest( + "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH", + "auth identity channel provider must match canonical identity", + ) +) + +type ProviderGrantReason string + +const ( + ProviderGrantReasonSignup ProviderGrantReason = "signup" + ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind" +) + +type AuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type AuthIdentityChannelKey struct { + ProviderType string + ProviderKey string + Channel string + ChannelAppID string + ChannelSubject string +} + +type CreateAuthIdentityInput struct { + UserID int64 + Canonical AuthIdentityKey + Channel *AuthIdentityChannelKey + Issuer *string + VerifiedAt *time.Time + Metadata map[string]any + ChannelMetadata map[string]any +} + +type BindAuthIdentityInput = CreateAuthIdentityInput + +type CreateAuthIdentityResult struct { + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey { + if r == nil || r.Identity == nil { + return AuthIdentityKey{} + } + return AuthIdentityKey{ + ProviderType: r.Identity.ProviderType, + ProviderKey: r.Identity.ProviderKey, + ProviderSubject: r.Identity.ProviderSubject, + } +} + +func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey { + if r == nil || r.Channel == nil { + return nil + } + return &AuthIdentityChannelKey{ + ProviderType: r.Channel.ProviderType, + ProviderKey: r.Channel.ProviderKey, + Channel: r.Channel.Channel, + ChannelAppID: r.Channel.ChannelAppID, + ChannelSubject: r.Channel.ChannelSubject, + } +} + +type UserAuthIdentityLookup struct { + User *dbent.User + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +type ProviderGrantRecordInput struct { + UserID int64 + ProviderType string + GrantReason ProviderGrantReason +} + +type IdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type sqlQueryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if dbent.TxFromContext(ctx) != nil { + return fn(ctx) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + return err + } + return tx.Commit() +} + +func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + + client := clientFromContext(ctx, r.client) + + create := client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt) + + identity, err := create.Save(ctx) + if err != nil { + return nil, err + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(ctx) + if err != nil { + return nil, err + } + } + + return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil +} + +func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) { + identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)), + ). + WithUser(). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: identity.Edges.User, + Identity: identity, + }, nil +} + +func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) { + channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: channel.Edges.Identity.Edges.User, + Identity: channel.Edges.Identity, + Channel: channel, + }, nil +} + +func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). + All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: copyMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + + return records, nil +} + +func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" || provider == "email" { + return service.ErrIdentityProviderInvalid + } + + return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + identityIDs, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ(provider), + ). + IDs(txCtx) + if err != nil { + return err + } + if len(identityIDs) == 0 { + return nil + } + + if _, err := client.IdentityAdoptionDecision.Update(). + Where(identityadoptiondecision.IdentityIDIn(identityIDs...)). + ClearIdentityID(). + Save(txCtx); err != nil { + return err + } + if _, err := client.AuthIdentityChannel.Delete(). + Where(authidentitychannel.IdentityIDIn(identityIDs...)). + Exec(txCtx); err != nil { + return err + } + _, err = client.AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ(provider), + ). + Exec(txCtx) + return err + }) +} + +func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + + var result *CreateAuthIdentityResult + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + canonical := input.Canonical + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), + ). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if identity != nil && identity.UserID != input.UserID { + return ErrAuthIdentityOwnershipConflict + } + if identity == nil { + identity, err = client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentity.UpdateOneID(identity.ID) + if input.Metadata != nil { + update = update.SetMetadata(copyMetadata(input.Metadata)) + } + if input.Issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*input.Issuer)) + } + if input.VerifiedAt != nil { + update = update.SetVerifiedAt(*input.VerifiedAt) + } + identity, err = update.Save(txCtx) + if err != nil { + return err + } + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), + ). + WithIdentity(). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID { + return ErrAuthIdentityChannelOwnershipConflict + } + if channel == nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID) + if input.ChannelMetadata != nil { + update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) + } + channel, err = update.Save(txCtx) + if err != nil { + return err + } + } + } + + result = &CreateAuthIdentityResult{Identity: identity, Channel: channel} + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return false, fmt.Errorf("sql executor is not configured") + } + + result, err := exec.ExecContext(ctx, ` +INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + input.UserID, + strings.TrimSpace(input.ProviderType), + string(input.GrantReason), + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + client := clientFromContext(ctx, r.client) + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(ctx); err != nil { + return nil, err + } + } + + current, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + now := time.Now().UTC() + if current == nil { + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(now) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { + return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)). + Only(ctx) +} + +func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastLoginAt(loginAt). + Save(ctx) + return err +} + +func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastActiveAt(activeAt). + Save(ctx) + return err +} + +func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + rows, err := exec.QueryContext(ctx, ` +SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 +FROM user_avatars +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + _, err = exec.ExecContext(ctx, ` +INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) +ON CONFLICT (user_id) DO UPDATE SET + storage_provider = EXCLUDED.storage_provider, + storage_key = EXCLUDED.storage_key, + url = EXCLUDED.url, + content_type = EXCLUDED.content_type, + byte_size = EXCLUDED.byte_size, + sha256 = EXCLUDED.sha256, + updated_at = NOW()`, + userID, + strings.TrimSpace(input.StorageProvider), + strings.TrimSpace(input.StorageKey), + strings.TrimSpace(input.URL), + strings.TrimSpace(input.ContentType), + input.ByteSize, + strings.TrimSpace(input.SHA256), + ) + if err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: strings.TrimSpace(input.StorageProvider), + StorageKey: strings.TrimSpace(input.StorageKey), + URL: strings.TrimSpace(input.URL), + ContentType: strings.TrimSpace(input.ContentType), + ByteSize: input.ByteSize, + SHA256: strings.TrimSpace(input.SHA256), + }, nil +} + +func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return err + } + _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID) + return err +} + +func copyMetadata(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error { + if channel == nil { + return nil + } + + canonicalProviderType := strings.TrimSpace(canonical.ProviderType) + canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey) + channelProviderType := strings.TrimSpace(channel.ProviderType) + channelProviderKey := strings.TrimSpace(channel.ProviderKey) + + if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey { + return ErrAuthIdentityChannelProviderMismatch + } + + return nil +} + +func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { + return exec + } + } + if fallback != nil { + return fallback + } + return sqlExecutorFromEntClient(client) +} + +func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + return exec, nil +} + +func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor { + if client == nil { + return nil + } + + clientValue := reflect.ValueOf(client).Elem() + configValue := clientValue.FieldByName("config") + driverValue := configValue.FieldByName("driver") + if !driverValue.IsValid() { + return nil + } + + driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface() + exec, ok := driver.(sqlQueryExecutor) + if !ok { + return nil + } + return exec +} diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go new file mode 100644 index 0000000000000000000000000000000000000000..697e96a4fd71d730e6b769873430701f3cc435e4 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -0,0 +1,503 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" +) + +type UserProfileIdentityRepoSuite struct { + suite.Suite + ctx context.Context + client *dbent.Client + repo *userRepository +} + +func TestUserProfileIdentityRepoSuite(t *testing.T) { + suite.Run(t, new(UserProfileIdentityRepoSuite)) +} + +func (s *UserProfileIdentityRepoSuite) SetupTest() { + s.ctx = context.Background() + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + _, err := integrationDB.ExecContext(s.ctx, ` +TRUNCATE TABLE + identity_adoption_decisions, + auth_identity_channels, + auth_identities, + pending_auth_sessions, + user_provider_default_grants, + user_avatars +RESTART IDENTITY`) + s.Require().NoError(err) +} + +func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User { + s.T().Helper() + + user, err := s.client.User.Create(). + SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())). + SetPasswordHash("test-password-hash"). + SetRole("user"). + SetStatus("active"). + Save(s.ctx) + s.Require().NoError(err) + return user +} + +func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession { + s.T().Helper() + + session, err := s.client.PendingAuthSession.Create(). + SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())). + SetIntent("bind_current_user"). + SetProviderType(key.ProviderType). + SetProviderKey(key.ProviderKey). + SetProviderSubject(key.ProviderSubject). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(s.ctx) + s.Require().NoError(err) + return session +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() { + user := s.mustCreateUser("canonical-channel") + + verifiedAt := time.Now().UTC().Truncate(time.Second) + created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + Channel: "mp", + ChannelAppID: "wx-app", + ChannelSubject: "openid-123", + }, + Issuer: stringPtr("https://issuer.example"), + VerifiedAt: &verifiedAt, + Metadata: map[string]any{"unionid": "union-123"}, + ChannelMetadata: map[string]any{"openid": "openid-123"}, + }) + s.Require().NoError(err) + s.Require().NotNil(created.Identity) + s.Require().NotNil(created.Channel) + + canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, canonical.User.ID) + s.Require().Equal(created.Identity.ID, canonical.Identity.ID) + s.Require().Equal("union-123", canonical.Identity.ProviderSubject) + + channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, channel.User.ID) + s.Require().Equal(created.Identity.ID, channel.Identity.ID) + s.Require().Equal(created.Channel.ID, channel.Channel.ID) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() { + owner := s.mustCreateUser("owner") + other := s.mustCreateUser("other") + + first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "first"}, + ChannelMetadata: map[string]any{"scope": "read"}, + }) + s.Require().NoError(err) + + second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "second"}, + ChannelMetadata: map[string]any{"scope": "write"}, + }) + s.Require().NoError(err) + s.Require().Equal(first.Identity.ID, second.Identity.ID) + s.Require().Equal(first.Channel.ID, second.Channel.ID) + s.Require().Equal("second", second.Identity.Metadata["username"]) + s.Require().Equal("write", second.Channel.Metadata["scope"]) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-2", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-create") + + _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-create-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "app-mismatch", + ChannelSubject: "openid-create-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-bind") + + _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-bind-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-legacy", + Channel: "oa", + ChannelAppID: "wx-app-bind-mismatch", + ChannelSubject: "openid-bind-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { + user := s.mustCreateUser("tx-rollback") + expectedErr := errors.New("rollback") + + err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }, + }) + s.Require().NoError(err) + + inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "oidc", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + return expectedErr + }) + s.Require().ErrorIs(err, expectedErr) + + _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }) + s.Require().True(dbent.IsNotFound(err)) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`, + user.ID, + "oidc", + string(ProviderGrantReasonFirstBind), + ).Scan(&count)) + s.Require().Zero(count) +} + +func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() { + user := s.mustCreateUser("grant") + + inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().False(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonSignup, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2`, + user.ID, + "wechat", + ).Scan(&count)) + s.Require().Equal(2, count) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() { + user := s.mustCreateUser("adoption") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + s.Require().NoError(err) + + session := s.mustCreatePendingAuthSession(identity.IdentityRef()) + + first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().True(first.AdoptDisplayName) + s.Require().False(first.AdoptAvatar) + s.Require().Nil(first.IdentityID) + + second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().Equal(first.ID, second.ID) + s.Require().NotNil(second.IdentityID) + s.Require().Equal(identity.Identity.ID, *second.IdentityID) + s.Require().True(second.AdoptAvatar) + + loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID) + s.Require().NoError(err) + s.Require().Equal(second.ID, loaded.ID) + s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() { + user := s.mustCreateUser("adoption-reassign") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption-reassign", + }, + }) + s.Require().NoError(err) + + firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: firstSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().NotNil(firstDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID) + + secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: secondSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().NotNil(secondDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID) + + reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID) + s.Require().NoError(err) + s.Require().Nil(reloadedFirst.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() { + user := s.mustCreateUser("avatar-only-update") + + model, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(model) + + err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + if err != nil { + return err + } + return s.repo.Update(txCtx, model) + }) + s.Require().NoError(err) + + avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(avatar) + s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL) +} + +func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { + user := s.mustCreateUser("avatar") + + inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: "data:image/png;base64,QUJD", + ContentType: "image/png", + ByteSize: 3, + SHA256: "902fbdd2b1df0c4f70b4a5d23525e932", + }) + s.Require().NoError(err) + s.Require().Equal("inline", inlineAvatar.StorageProvider) + s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL) + + loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("image/png", loadedAvatar.ContentType) + s.Require().Equal(3, loadedAvatar.ByteSize) + + _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + s.Require().NoError(err) + + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("remote_url", loadedAvatar.StorageProvider) + s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL) + s.Require().Zero(loadedAvatar.ByteSize) + + s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID)) + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Nil(loadedAvatar) +} + +func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() { + user := s.mustCreateUser("activity") + loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + activeAt := loginAt.Add(5 * time.Minute) + + s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt)) + s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt)) + + var storedLoginAt sqlNullTime + var storedActiveAt sqlNullTime + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT last_login_at, last_active_at +FROM users +WHERE id = $1`, + user.ID, + ).Scan(&storedLoginAt, &storedActiveAt)) + s.Require().True(storedLoginAt.Valid) + s.Require().True(storedActiveAt.Valid) + s.Require().True(storedLoginAt.Time.Equal(loginAt)) + s.Require().True(storedActiveAt.Time.Equal(activeAt)) +} + +type sqlNullTime struct { + Time time.Time + Valid bool +} + +func (t *sqlNullTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + t.Time = v + t.Valid = true + return nil + case nil: + t.Time = time.Time{} + t.Valid = false + return nil + default: + return fmt.Errorf("unsupported scan type %T", value) + } +} + +func stringPtr(v string) *string { + return &v +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 913e1c4000595b454f22e3f52d93d06f24aca4e1..c7d301c76aa452c746b25a4a997acb3b8fff5f4b 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -11,12 +11,17 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" entsql "entgo.io/ent/dialect/sql" ) @@ -51,8 +56,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error defer func() { _ = tx.Rollback() }() txClient = tx.Client() } else { - // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 - txClient = r.client + // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } } created, err := txClient.User.Create(). @@ -64,6 +73,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). + SetNillableLastLoginAt(userIn.LastLoginAt). + SetNillableLastActiveAt(userIn.LastActiveAt). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -72,6 +84,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { return err } + if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { + return err + } if tx != nil { if err := tx.Commit(); err != nil { @@ -101,10 +116,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, } func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { - m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email)) } + m := matches[0] out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) @@ -133,9 +158,18 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error defer func() { _ = tx.Rollback() }() txClient = tx.Client() } else { - // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 - txClient = r.client + // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } } + existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + oldEmail := existing.Email updateOp := txClient.User.UpdateOneID(userIn.ID). SetEmail(userIn.Email). @@ -151,6 +185,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetTotalRecharged(userIn.TotalRecharged) + if userIn.SignupSource != "" { + updateOp = updateOp.SetSignupSource(userIn.SignupSource) + } + if userIn.LastLoginAt != nil { + updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt) + } + if userIn.LastActiveAt != nil { + updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt) + } if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } @@ -162,6 +205,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { return err } + if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { + return err + } if tx != nil { if err := tx.Commit(); err != nil { @@ -173,14 +219,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error return nil } +func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error { + client = clientFromContext(ctx, client) + if client == nil || userID <= 0 { + return nil + } + + subject := normalizeEmailAuthIdentitySubject(email) + if subject == "" { + return nil + } + + if err := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(subject). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": source}). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return err + } + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(subject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity.UserID != userID { + return ErrAuthIdentityOwnershipConflict + } + return nil +} + +func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error { + newSubject := normalizeEmailAuthIdentitySubject(newEmail) + if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil { + return err + } + + oldSubject := normalizeEmailAuthIdentitySubject(oldEmail) + if oldSubject == "" || oldSubject == newSubject { + return nil + } + + _, err := clientFromContext(ctx, client).AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(oldSubject), + ). + Exec(ctx) + return err +} + +func normalizeEmailAuthIdentitySubject(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" { + return "" + } + if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) { + return "" + } + return normalized +} + func (r *userRepository) Delete(ctx context.Context, id int64) error { - affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx) + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } + } + + identityIDs, err := txClient.AuthIdentity.Query(). + Where(authidentity.UserIDEQ(id)). + IDs(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if len(identityIDs) > 0 { + if _, err := txClient.IdentityAdoptionDecision.Update(). + Where(identityadoptiondecision.IdentityIDIn(identityIDs...)). + ClearIdentityID(). + Save(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if _, err := txClient.AuthIdentityChannel.Delete(). + Where(authidentitychannel.IdentityIDIn(identityIDs...)). + Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if _, err := txClient.AuthIdentity.Delete(). + Where(authidentity.UserIDEQ(id)). + Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + } + + affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } if affected == 0 { return service.ErrUserNotFound } + + if tx != nil { + if err := tx.Commit(); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + } return nil } @@ -298,8 +476,13 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + if sortBy == "last_used_at" { + return userLastUsedAtOrder(sortOrder) + } + var field string defaultField := true + nullsLastField := false switch sortBy { case "email": field = dbuser.FieldEmail @@ -322,6 +505,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) case "created_at": field = dbuser.FieldCreatedAt defaultField = false + case "last_active_at": + field = dbuser.FieldLastActiveAt + defaultField = false + nullsLastField = true default: field = dbuser.FieldID } @@ -330,14 +517,92 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(), + dbent.Asc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} } if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(), + dbent.Desc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} } +func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + result := make(map[int64]*time.Time, len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + if r.sql == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + const query = ` + SELECT user_id, MAX(created_at) AS last_used_at + FROM usage_logs + WHERE user_id = ANY($1) + GROUP BY user_id + ` + + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var ( + userID int64 + lastUsedAt time.Time + ) + if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil { + return nil, scanErr + } + ts := lastUsedAt.UTC() + result[userID] = &ts + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID}) + if err != nil { + return nil, err + } + return latestByUserID[userID], nil +} + +func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) { + orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) { + return func(s *entsql.Selector) { + subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID)) + s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls)) + s.OrderBy(tieOrder(s.C(dbuser.FieldID))) + } + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){ + orderExpr("ASC", "FIRST", entsql.Asc), + } + } + return []func(*entsql.Selector){ + orderExpr("DESC", "LAST", entsql.Desc), + } +} + // filterUsersByAttributes returns user IDs that match ALL the given attribute filters func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) { if len(attrs) == 0 { @@ -436,17 +701,36 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { - return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) + return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) +} + +func userEmailLookupPredicate(email string) predicate.User { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.P(func(b *entsql.Builder) { + b.WriteString("LOWER(TRIM("). + Ident(s.C(dbuser.FieldEmail)). + WriteString(")) = "). + Arg(normalized) + })) + }) } func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { client := clientFromContext(ctx, r.client) - return client.UserAllowedGroup.Create(). + err := client.UserAllowedGroup.Create(). SetUserID(userID). SetGroupID(groupID). OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). DoNothing(). Exec(ctx) + if isSQLNoRowsError(err) { + return nil + } + return err } func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { @@ -546,6 +830,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). DoNothing(). Exec(ctx); err != nil { + if isSQLNoRowsError(err) { + return nil + } return err } } @@ -558,10 +845,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { return } dst.ID = src.ID + dst.SignupSource = src.SignupSource + dst.LastLoginAt = src.LastLoginAt + dst.LastActiveAt = src.LastActiveAt dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } +func userSignupSourceOrDefault(signupSource string) string { + signupSource = strings.TrimSpace(signupSource) + if signupSource == "" { + return "email" + } + return signupSource +} + // marshalExtraEmails serializes notify email entries to JSON for storage. func marshalExtraEmails(entries []service.NotifyEmailEntry) string { return service.MarshalNotifyEmails(entries) diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fddd82c5981a220cf541083fa82fc088dd81e55c --- /dev/null +++ b/backend/internal/repository/user_repo_email_identity_integration_test.go @@ -0,0 +1,86 @@ +//go:build integration + +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() { + user := &service.User{ + Email: "repo-create@example.com", + PasswordHash: "test-password-hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 2, + } + + s.Require().NoError(s.repo.Create(s.ctx, user)) + + identity, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("repo-create@example.com"), + ). + Only(s.ctx) + s.Require().NoError(err) + s.Require().Equal(user.ID, identity.UserID) +} + +func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() { + user := &service.User{ + Email: "linuxdo-legacy-user@linuxdo-connect.invalid", + PasswordHash: "test-password-hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 2, + } + + s.Require().NoError(s.repo.Create(s.ctx, user)) + + count, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(count) +} + +func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() { + user := s.mustCreateUser(&service.User{ + Email: "before-update@example.com", + }) + + user.Email = "after-update@example.com" + s.Require().NoError(s.repo.Update(s.ctx, user)) + + newIdentity, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("after-update@example.com"), + ). + Only(s.ctx) + s.Require().NoError(err) + s.Require().Equal(user.ID, newIdentity.UserID) + + oldCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("before-update@example.com"), + ). + Count(context.Background()) + s.Require().NoError(err) + s.Require().Zero(oldCount) +} diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d42ce9ac4cf8e62c3db5a17bb5be0a6e8f4fae02 --- /dev/null +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -0,0 +1,69 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return newUserRepositoryWithSQL(client, db), client +} + +func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + got, err := repo.GetByEmail(ctx, "legacy@example.com") + require.NoError(t, err) + require.Equal(t, " Legacy@Example.com ", got.Email) +} + +func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ") + require.NoError(t, err) + require.True(t, exists) +} diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index f5d0f9ff1893024e7187c2283e62a16c0ad3ad3c..13a605a2f5467dc6685bc1f5b05941d763416d53 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -8,6 +8,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() { s.repo = newUserRepositoryWithSQL(s.client, integrationDB) // 清理测试数据,确保每个测试从干净状态开始 + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels") + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users") @@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() { s.Require().Equal(user.ID, got.ID) } +func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() { + user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "}) + + got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ") + s.Require().NoError(err, "GetByEmail normalized lookup") + s.Require().Equal(user.ID, got.ID) +} + func (s *UserRepoSuite) TestGetByEmail_NotFound() { _, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com") s.Require().Error(err, "expected error for non-existent email") } +func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() { + s.mustCreateUser(&service.User{Email: " Legacy@Example.com "}) + + exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ") + s.Require().NoError(err, "ExistsByEmail normalized lookup") + s.Require().True(exists) +} + func (s *UserRepoSuite) TestUpdate() { user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"}) @@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() { s.Require().Equal("updated", updated.Username) } +func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() { + user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"}) + + identityCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("update-existing-identity@test.com"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, identityCount) + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + got.Username = "updated" + s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows") + + updated, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal("updated", updated.Username) +} + func (s *UserRepoSuite) TestDelete() { user := s.mustCreateUser(&service.User{Email: "delete@test.com"}) @@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() { s.Require().Error(err, "expected error after delete") } +func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() { + user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"}) + + identity, err := s.client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("delete-oauth-subject"). + Save(s.ctx) + s.Require().NoError(err) + + _, err = s.client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("open"). + SetChannelAppID("app-id"). + SetChannelSubject("openid-123"). + Save(s.ctx) + s.Require().NoError(err) + + err = s.repo.Delete(s.ctx, user.ID) + s.Require().NoError(err) + + identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(identityCount) + + channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(channelCount) +} + // --- List / ListWithFilters --- func (s *UserRepoSuite) TestList() { diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go index ab84b0e93bb969bfaf1c863cdb30067cc15a8942..3a15bc1024b66033002876732a1047a10ab50595 100644 --- a/backend/internal/repository/user_repo_sort_integration_test.go +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -4,11 +4,30 @@ package repository import ( "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" ) +func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) { + s.T().Helper() + + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID}) + + _, err := integrationDB.ExecContext( + s.ctx, + `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at) + VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`, + userID, + apiKey.ID, + account.ID, + createdAt.UTC(), + ) + s.Require().NoError(err) +} + func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() { s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"}) s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"}) @@ -36,4 +55,110 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { s.Require().Equal(first.ID, users[1].ID) } +func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() { + lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond) + + created := s.mustCreateUser(&service.User{ + Email: "identity-meta@example.com", + SignupSource: "linuxdo", + LastLoginAt: &lastLoginAt, + LastActiveAt: &lastActiveAt, + }) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("linuxdo", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() { + created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"}) + lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond) + + created.SignupSource = "oidc" + created.LastLoginAt = &lastLoginAt + created.LastActiveAt = &lastActiveAt + + s.Require().NoError(s.repo.Update(s.ctx, created)) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("oidc", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() { + earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond) + later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-active@example.com"}) + s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later}) + s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_active_at", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("earlier-active@example.com", users[0].Email) + s.Require().Equal("later-active@example.com", users[1].Email) + s.Require().Equal("nil-active@example.com", users[2].Email) +} + +func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() { + older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second) + newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second) + + userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"}) + userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"}) + s.mustInsertUsageLog(userWithUsage.ID, older) + s.mustInsertUsageLog(userWithUsage.ID, newer) + + got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID}) + s.Require().NoError(err) + s.Require().Contains(got, userWithUsage.ID) + s.Require().NotContains(got, userWithoutUsage.ID) + s.Require().NotNil(got[userWithUsage.ID]) + s.Require().True(got[userWithUsage.ID].Equal(newer)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() { + lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second) + lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second) + lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second) + + nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"}) + wrongSource := s.mustCreateUser(&service.User{ + Email: "active-not-usage@example.com", + LastActiveAt: &lastActiveVeryRecent, + }) + rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"}) + + s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder) + s.mustInsertUsageLog(rightSource.ID, lastUsedNewer) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_used_at", + SortOrder: "desc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal(rightSource.ID, users[0].ID) + s.Require().Equal(wrongSource.ID, users[1].ID) + s.Require().Equal(nilUsage.ID, users[2].ID) +} + func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b686b986faee962e777845a584dc6a553c9128af..ed7764cfb1ee002a000ec37b9634c3a3d4d1ed00 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -50,6 +50,7 @@ func TestAPIContracts(t *testing.T) { "data": { "id": 1, "email": "alice@example.com", + "email_bound": true, "username": "alice", "role": "user", "balance": 12.5, @@ -63,6 +64,120 @@ func TestAPIContracts(t *testing.T) { "balance_notify_threshold": null, "balance_notify_extra_emails": null, "total_recharged": 0, + "linuxdo_bound": false, + "oidc_bound": false, + "wechat_bound": false, + "identities": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "identity_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "auth_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, "run_mode": "standard" } }`, @@ -479,7 +594,7 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectUsePKCE: "true", service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectClockSkewSeconds: "120", @@ -500,10 +615,15 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyTableDefaultPageSize: "20", service.SettingKeyTablePageSizeOptions: "[10,20,50,100]", - service.SettingKeyOpsMonitoringEnabled: "false", - service.SettingKeyOpsRealtimeMonitoringEnabled: "true", - service.SettingKeyOpsQueryModeDefault: "auto", - service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingKeyOpsMonitoringEnabled: "false", + service.SettingKeyOpsRealtimeMonitoringEnabled: "true", + service.SettingKeyOpsQueryModeDefault: "auto", + service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay, + service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat, + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + "openai_advanced_scheduler_enabled": "true", }) }, method: http.MethodGet, @@ -549,7 +669,7 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_redirect_url": "", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_token_auth_method": "client_secret_post", - "oidc_connect_use_pkce": false, + "oidc_connect_use_pkce": true, "oidc_connect_validate_id_token": true, "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_clock_skew_seconds": 120, @@ -567,6 +687,27 @@ func TestAPIContracts(t *testing.T) { "api_base_url": "https://api.example.com", "contact_info": "support", "doc_url": "https://docs.example.com", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + "auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, "default_subscriptions": [], @@ -592,6 +733,11 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "easypay_alipay", + "payment_visible_method_wxpay_source": "official_wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -618,7 +764,23 @@ func TestAPIContracts(t *testing.T) { "account_quota_notify_enabled": false, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", - "account_quota_notify_emails": [] + "account_quota_notify_emails": [], + "wechat_connect_enabled": false, + "wechat_connect_app_id": "", + "wechat_connect_app_secret_configured": false, + "wechat_connect_mode": "open", + "wechat_connect_open_enabled": false, + "wechat_connect_open_app_id": "", + "wechat_connect_open_app_secret_configured": false, + "wechat_connect_mp_enabled": false, + "wechat_connect_mp_app_id": "", + "wechat_connect_mp_app_secret_configured": false, + "wechat_connect_mobile_enabled": false, + "wechat_connect_mobile_app_id": "", + "wechat_connect_mobile_app_secret_configured": false, + "wechat_connect_redirect_url": "", + "wechat_connect_frontend_redirect_url": "/auth/wechat/callback", + "wechat_connect_scopes": "snsapi_login" } }`, }, @@ -858,6 +1020,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } +func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -894,6 +1068,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 return errors.New("not implemented") } +func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + return nil, nil +} + +func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + return nil, nil +} + +func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + return nil +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index ed2578c843f1f9ef38b0a4c6110d0468cadebe6d..06e3355e5ea2b7d1c127135163412af8d8baefce 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { panic("unexpected Delete call") } +func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { panic("unexpected List call") } @@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa panic("unexpected ListWithFilters call") } +func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { panic("unexpected UpdateBalance call") } @@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 4aceb3550258b23efdaad7aced6ebd95b47d32ea..48cb9004bfb918f754bee37f04491017882b1cc8 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "errors" "strings" @@ -11,11 +12,19 @@ import ( // NewJWTAuthMiddleware 创建 JWT 认证中间件 func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware { - return JWTAuthMiddleware(jwtAuth(authService, userService)) + return JWTAuthMiddleware(jwtAuth(authService, userService, userService)) +} + +type jwtUserReader interface { + GetByID(ctx context.Context, id int64) (*service.User, error) +} + +type userActivityToucher interface { + TouchLastActiveForUser(ctx context.Context, user *service.User) } // jwtAuth JWT认证中间件实现 -func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc { +func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc { return func(c *gin.Context) { // 从Authorization header中提取token authHeader := c.GetHeader("Authorization") @@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) Concurrency: user.Concurrency, }) c.Set(string(ContextKeyUserRole), user.Role) + if activityToucher != nil { + activityToucher.TouchLastActiveForUser(c.Request.Context(), user) + } c.Next() } diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index c483a51eafca6b83e5867277ae4c227046494e07..84fd696739dfa61332e4d12c0dd14611f77b2b11 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" @@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e return u, nil } +func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error { + return nil +} + +type recordingActivityToucher struct { + userIDs []int64 +} + +func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) { + if user == nil { + return + } + r.userIDs = append(r.userIDs, user.ID) +} + // newJWTTestEnv 创建 JWT 认证中间件测试环境。 // 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { @@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil, nil) + toucher := &recordingActivityToucher{} + + r := gin.New() + r.Use(jwtAuth(authSvc, userSvc, toucher)) + r.GET("/protected", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, []int64{1}, toucher.userIDs) +} + func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { router, _ := newJWTTestEnv(nil) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 9af0fd8ef6ab886f66f4b4fed2ea42099f01db2d..84c963ec24241af0b7187b716843f6033a8cfcae 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) + users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) users.POST("", h.Admin.User.Create) users.PUT("/:id", h.Admin.User.Update) users.DELETE("/:id", h.Admin.User.Delete) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c143b030fc88da423867e4bf77404474d9039f78..f1032eb52479cc8a580ab4882870971886b8247a 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -64,12 +64,70 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) + auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart) + auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback) + auth.POST("/oauth/pending/exchange", + rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.ExchangePendingOAuthCompletion, + ) + auth.POST("/oauth/pending/send-verify-code", + rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.SendPendingOAuthVerifyCode, + ) + auth.POST("/oauth/pending/create-account", + rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreatePendingOAuthAccount, + ) + auth.POST("/oauth/pending/bind-login", + rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindPendingOAuthLogin, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/linuxdo/bind-login", + rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindLinuxDoOAuthLogin, + ) + auth.POST("/oauth/linuxdo/create-account", + rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateLinuxDoOAuthAccount, + ) + auth.POST("/oauth/wechat/complete-registration", + rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteWeChatOAuthRegistration, + ) + auth.POST("/oauth/wechat/bind-login", + rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindWeChatOAuthLogin, + ) + auth.POST("/oauth/wechat/create-account", + rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateWeChatOAuthAccount, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", @@ -78,6 +136,18 @@ func RegisterAuthRoutes( }), h.Auth.CompleteOIDCOAuthRegistration, ) + auth.POST("/oauth/oidc/bind-login", + rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindOIDCOAuthLogin, + ) + auth.POST("/oauth/oidc/create-account", + rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateOIDCOAuthAccount, + ) } // 公开设置(无需认证) @@ -94,5 +164,23 @@ func RegisterAuthRoutes( authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) + authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.LinuxDoOAuthStart(c) + }) + authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.OIDCOAuthStart(c) + }) + authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.WeChatOAuthStart(c) + }) } } diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go index 4f411cec570a3595248d7d2b86f440f6b0bdf119..07a66efb4e580cbade54a8befa4cf00b63d06783 100644 --- a/backend/internal/server/routes/auth_rate_limit_test.go +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { "/api/v1/auth/login", "/api/v1/auth/login/2fa", "/api/v1/auth/send-verify-code", + "/api/v1/auth/oauth/pending/send-verify-code", } for _, path := range paths { diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 23bd58ad8eb20335e6169cc3cc89e4c445d15455..ec340d94eabee7d86f9794003b1ca0879fc6e331 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -44,11 +44,13 @@ func RegisterPaymentRoutes( } // --- Public payment endpoints (no auth) --- - // Payment result page needs to verify order status without login - // (user session may have expired during provider redirect). + // Signed resume-token recovery is the supported public lookup path. + // The legacy anonymous out_trade_no verify endpoint is kept only as a + // compatibility shim that returns HTTP 410 Gone. public := v1.Group("/payment/public") { public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) + public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken) } // --- Webhook endpoints (no auth) --- diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index d004f8b4391d02c28a40e8d8f0f4ae50f881be4c..b76bb3cd2c19300bf2a38c5c5bbad04e31725336 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,10 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode) + user.POST("/account-bindings/email", h.User.BindEmailIdentity) + user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) + user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 701f3659ee03ac052f7201f036f43ae01cca41d2..110c90083b1313c03bdefdad474785787cf52b4f 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -11,6 +12,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -33,6 +36,7 @@ type AdminService interface { // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) + BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -127,6 +131,44 @@ type UpdateUserInput struct { GroupRates map[int64]*float64 } +type AdminBindAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Issuer *string + Metadata map[string]any + Channel *AdminBindAuthIdentityChannelInput +} + +type AdminBindAuthIdentityChannelInput struct { + Channel string + ChannelAppID string + ChannelSubject string + Metadata map[string]any +} + +type AdminBoundAuthIdentity struct { + UserID int64 `json:"user_id"` + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"` +} + +type AdminBoundAuthIdentityChannel struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateGroupInput struct { Name string Description string @@ -491,6 +533,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi if err != nil { return nil, 0, err } + if len(users) > 0 { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr) + } else { + for i := range users { + users[i].LastUsedAt = lastUsedByUserID[users[i].ID] + } + } + } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { @@ -535,6 +591,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) if err != nil { return nil, err } + lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr) + } else { + user.LastUsedAt = lastUsedAt + } // 加载用户专属分组倍率 if s.userGroupRateRepo != nil { rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) @@ -797,6 +859,227 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int return codes, result.Total, totalRecharged, nil } +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identity, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channel, err = tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} + +func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { + if input == nil { + return nil + } + channel := &AdminBindAuthIdentityChannelInput{ + Channel: strings.TrimSpace(input.Channel), + ChannelAppID: strings.TrimSpace(input.ChannelAppID), + ChannelSubject: strings.TrimSpace(input.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(input.Metadata), + } + if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" { + return nil + } + return channel +} + +func normalizeAdminAuthIdentityProviderType(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "email": + return "email" + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + default: + return "" + } +} + +func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity { + if identity == nil { + return nil + } + result := &AdminBoundAuthIdentity{ + UserID: identity.UserID, + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + } + if channel != nil { + result.Channel = &AdminBoundAuthIdentityChannel{ + Channel: strings.TrimSpace(channel.Channel), + ChannelAppID: strings.TrimSpace(channel.ChannelAppID), + ChannelSubject: strings.TrimSpace(channel.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata), + CreatedAt: channel.CreatedAt, + UpdatedAt: channel.UpdatedAt, + } + } + return result +} + +func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { + if input == nil { + return nil + } + if len(input) == 0 { + return map[string]any{} + } + data, err := json.Marshal(input) + if err != nil { + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + out = make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + } + return out +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 419ddbc329642e9717d545319c6f53afd297347c..fcde5cbf4abd89d7e666c378ad7bee8666c64146 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } @@ -70,6 +79,23 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected") +} + +func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected") +} + +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f8ce393536662e478b6d2e6462404a4d65c504f0 --- /dev/null +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -0,0 +1,215 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-123", + Metadata: map[string]any{"scene": "migration"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "wechat", result.ProviderType) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.VerifiedAt) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + require.NotNil(t, identity.VerifiedAt) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "migration", channel.Metadata["scene"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + target, err := client.User.Create(). + SetEmail("target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("subject-1"). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }) + require.Error(t, err) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err)) +} + +func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("same-user@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "first"}, + }) + require.NoError(t, err) + + second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "second"}, + }) + require.NoError(t, err) + require.Equal(t, first.UserID, second.UserID) + require.Equal(t, "second", second.Metadata["source"]) + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("subject-2"), + ). + All(ctx) + require.NoError(t, err) + require.Len(t, identities, 1) + require.Equal(t, "second", identities[0].Metadata["source"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("invalid-provider@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "github", + ProviderKey: "github-main", + ProviderSubject: "subject-3", + }) + require.Error(t, err) + require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fbc856cf3c1f9541141047b6eea0cec211080edb..fe9e7701a26c1916aa7db35b7d2b732552eac422 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -13,15 +13,18 @@ import ( ) type userRepoStub struct { - user *User - getErr error - createErr error - deleteErr error - exists bool - existsErr error - nextID int64 - created []*User - deletedIDs []int64 + user *User + getErr error + createErr error + deleteErr error + exists bool + existsErr error + nextID int64 + created []*User + updated []*User + deletedIDs []int64 + usersByEmail map[string]*User + getByEmailErr error } func (s *userRepoStub) Create(ctx context.Context, user *User) error { @@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error { user.ID = s.nextID } s.created = append(s.created, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user return nil } @@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { } func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { - panic("unexpected GetByEmail call") + if s.getByEmailErr != nil { + return nil, s.getByEmailErr + } + if s.usersByEmail != nil { + if user, ok := s.usersByEmail[email]; ok { + return user, nil + } + } + if s.user != nil && s.user.Email == email { + return s.user, nil + } + return nil, ErrUserNotFound } func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { @@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { } func (s *userRepoStub) Update(ctx context.Context, user *User) error { - panic("unexpected Update call") + s.updated = append(s.updated, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user + return nil } func (s *userRepoStub) Delete(ctx context.Context, id int64) error { @@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { return s.deleteErr } +func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected List call") } @@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa panic("unexpected ListWithFilters call") } +func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error { panic("unexpected UpdateBalance call") } @@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2232c9c38b6e9994233e569b6befc976c8a2fd63 --- /dev/null +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -0,0 +1,187 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type ensureEmailCall struct { + userID int64 + email string +} + +type replaceEmailCall struct { + userID int64 + oldEmail string + newEmail string +} + +type emailSyncRepoStub struct { + user *User + nextID int64 + updateCalls int + created []*User + updated []*User + ensureCalls []ensureEmailCall + replaceCalls []replaceEmailCall + ensureErr error + replaceErr error +} + +func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error { + if s.nextID != 0 && user.ID == 0 { + user.ID = s.nextID + } + s.created = append(s.created, user) + s.user = user + return nil +} + +func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) { + if s.user == nil { + return nil, ErrUserNotFound + } + cloned := *s.user + return &cloned, nil +} + +func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) { + return nil, ErrUserNotFound +} + +func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) { + return nil, fmt.Errorf("unexpected GetFirstAdmin call") +} + +func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error { + s.updateCalls++ + s.updated = append(s.updated, user) + s.user = user + return nil +} + +func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + return nil, fmt.Errorf("unexpected GetUserAvatar call") +} + +func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + return nil, fmt.Errorf("unexpected UpsertUserAvatar call") +} + +func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error { + return fmt.Errorf("unexpected DeleteUserAvatar call") +} + +func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, fmt.Errorf("unexpected List call") +} + +func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, fmt.Errorf("unexpected ListWithFilters call") +} + +func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + return nil +} + +func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } + +func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } + +func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } + +func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } + +func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } + +func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil } + +func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } + +func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error { + s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email}) + return s.ensureErr +} + +func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error { + s.replaceCalls = append(s.replaceCalls, replaceEmailCall{ + userID: userID, + oldEmail: oldEmail, + newEmail: newEmail, + }) + return s.replaceErr +} + +func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + nextID: 55, + ensureErr: fmt.Errorf("unexpected email resync"), + } + svc := &adminServiceImpl{userRepo: repo} + + user, err := svc.CreateUser(context.Background(), &CreateUserInput{ + Email: "admin-created@example.com", + Password: "strong-pass", + }) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(55), user.ID) + require.Empty(t, repo.ensureCalls) + require.Empty(t, repo.replaceCalls) +} + +func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + user: &User{ + ID: 91, + Email: "before@example.com", + Role: RoleUser, + Status: StatusActive, + Concurrency: 3, + }, + replaceErr: fmt.Errorf("unexpected email resync"), + } + svc := &adminServiceImpl{userRepo: repo} + + updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{ + Email: "after@example.com", + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, "after@example.com", updated.Email) + require.Empty(t, repo.replaceCalls) + require.Empty(t, repo.ensureCalls) +} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index ceeb52c2944c830d82a071fecae79f4a00719ad2..657616c4e9a57ca3a3738810c619ff06601aa496 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" @@ -16,6 +17,8 @@ type userRepoStubForListUsers struct { users []User err error listWithFiltersParams pagination.PaginationParams + lastUsedByUserID map[int64]*time.Time + lastUsedErr error } func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { @@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag }, nil } +func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) { + if s.lastUsedErr != nil { + return nil, s.lastUsedErr + } + result := make(map[int64]*time.Time, len(userIDs)) + for _, userID := range userIDs { + if ts, ok := s.lastUsedByUserID[userID]; ok { + result[userID] = ts + } + } + return result, nil +} + +func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) { + if s.lastUsedErr != nil { + return nil, s.lastUsedErr + } + return s.lastUsedByUserID[userID], nil +} + type userGroupRateRepoStubForListUsers struct { batchCalls int singleCall []int64 @@ -130,3 +153,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) { SortOrder: "ASC", }, userRepo.listWithFiltersParams) } + +func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) { + lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second) + userRepo := &userRepoStubForListUsers{ + users: []User{{ID: 101, Email: "u@example.com"}}, + lastUsedByUserID: map[int64]*time.Time{ + 101: &lastUsed, + }, + } + svc := &adminServiceImpl{userRepo: userRepo} + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "") + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, users, 1) + require.NotNil(t, users[0].LastUsedAt) + require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second) +} diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go index 25c66eb43746944899934e403cd4d96b89cdf7f6..02741d37ba8a960edc8b1e36bea893e99595beea 100644 --- a/backend/internal/service/announcement.go +++ b/backend/internal/service/announcement.go @@ -5,6 +5,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -34,8 +35,23 @@ const ( ) var ( - ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound - ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound + ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required") + ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid") + ErrAnnouncementContentRequired = infraerrors.BadRequest( + "ANNOUNCEMENT_CONTENT_REQUIRED", + "announcement content is required", + ) + ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid") + ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest( + "ANNOUNCEMENT_NOTIFY_MODE_INVALID", + "announcement notify_mode is invalid", + ) + ErrAnnouncementInvalidSchedule = infraerrors.BadRequest( + "ANNOUNCEMENT_TIME_RANGE_INVALID", + "starts_at must be before ends_at", + ) ) type AnnouncementTargeting = domain.AnnouncementTargeting diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go index c0a0681ac9e75eb9451a2e862a742e9c5823f20f..124790419b88ea55b3f4efef9041ec71f78b4938 100644 --- a/backend/internal/service/announcement_service.go +++ b/backend/internal/service/announcement_service.go @@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct { func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("create announcement: nil input") + return nil, ErrAnnouncementNilInput } title := strings.TrimSpace(input.Title) content := strings.TrimSpace(input.Content) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("create announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } if content == "" { - return nil, fmt.Errorf("create announcement: content is required") + return nil, ErrAnnouncementContentRequired } status := strings.TrimSpace(input.Status) @@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem status = AnnouncementStatusDraft } if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("create announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate() @@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem notifyMode = AnnouncementNotifyModeSilent } if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("create announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } if input.StartsAt != nil && input.EndsAt != nil { if !input.StartsAt.Before(*input.EndsAt) { - return nil, fmt.Errorf("create announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } @@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("update announcement: nil input") + return nil, ErrAnnouncementNilInput } a, err := s.announcementRepo.GetByID(ctx, id) @@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.Title != nil { title := strings.TrimSpace(*input.Title) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("update announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } a.Title = title } if input.Content != nil { content := strings.TrimSpace(*input.Content) if content == "" { - return nil, fmt.Errorf("update announcement: content is required") + return nil, ErrAnnouncementContentRequired } a.Content = content } if input.Status != nil { status := strings.TrimSpace(*input.Status) if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("update announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } a.Status = status } @@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.NotifyMode != nil { notifyMode := strings.TrimSpace(*input.NotifyMode) if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("update announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } a.NotifyMode = notifyMode } @@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if a.StartsAt != nil && a.EndsAt != nil { if !a.StartsAt.Before(*a.EndsAt) { - return nil, fmt.Errorf("update announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..77fb9896e1728a046291fc337dc898b3d2e6e60a --- /dev/null +++ b/backend/internal/service/announcement_service_test.go @@ -0,0 +1,81 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type announcementRepoStub struct { + item *Announcement +} + +func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) { + if s.item == nil { + return nil, ErrAnnouncementNotFound + } + return s.item, nil +} + +func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (*announcementRepoStub) Delete(context.Context, int64) error { + return nil +} + +func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) { + return nil, nil +} + +func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{} + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + + _, err := svc.Create(context.Background(), &CreateAnnouncementInput{ + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + StartsAt: &now, + EndsAt: &now, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +} + +func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{ + item: &Announcement{ + ID: 1, + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + }, + } + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + startsAt := &now + endsAt := &now + + _, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{ + StartsAt: &startsAt, + EndsAt: &endsAt, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +} diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go new file mode 100644 index 0000000000000000000000000000000000000000..f0483800abbbbadfad616bef30dc787ed3e2a73d --- /dev/null +++ b/backend/internal/service/auth_email_binding.go @@ -0,0 +1,310 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +// BindEmailIdentity verifies and binds a local email/password identity to the +// current user, or replaces the existing bound primary email. +func (s *AuthService) BindEmailIdentity( + ctx context.Context, + userID int64, + email string, + verifyCode string, + password string, +) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return nil, err + } + if isReservedEmail(normalizedEmail) { + return nil, ErrEmailReserved + } + if strings.TrimSpace(password) == "" { + return nil, ErrPasswordRequired + } + if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil { + return nil, err + } + + currentUser, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) + if firstRealEmailBind && len(password) < 6 { + return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters") + } + if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) { + return nil, ErrPasswordIncorrect + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return nil, ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return nil, ErrServiceUnavailable + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, fmt.Errorf("hash password: %w", err) + } + + if s.entClient != nil { + if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil { + return nil, err + } + return currentUser, nil + } + + currentUser.Email = normalizedEmail + currentUser.PasswordHash = hashedPassword + if err := s.userRepo.Update(ctx, currentUser); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, ErrEmailExists + } + return nil, ErrServiceUnavailable + } + + if firstRealEmailBind { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil { + return nil, fmt.Errorf("apply email first bind defaults: %w", err) + } + } + + return currentUser, nil +} + +// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows. +func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error { + if s == nil { + return ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return err + } + if isReservedEmail(normalizedEmail) { + return ErrEmailReserved + } + if s.emailService == nil { + return ErrServiceUnavailable + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + if errors.Is(err, ErrUserNotFound) { + return ErrUserNotFound + } + return ErrServiceUnavailable + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName) +} + +func normalizeEmailForIdentityBinding(email string) (string, error) { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || len(normalized) > 255 { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(normalized); err != nil { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + return normalized, nil +} + +func hasBindableEmailIdentitySubject(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return normalized != "" && !isReservedEmail(normalized) +} + +func (s *AuthService) updateBoundEmailIdentityTx( + ctx context.Context, + currentUser *User, + email string, + hashedPassword string, + applyFirstBindDefaults bool, +) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return ErrServiceUnavailable + } + return nil +} + +func (s *AuthService) updateBoundEmailIdentityWithClient( + ctx context.Context, + client *dbent.Client, + currentUser *User, + email string, + hashedPassword string, + applyFirstBindDefaults bool, +) error { + if client == nil || currentUser == nil || currentUser.ID <= 0 { + return ErrServiceUnavailable + } + + oldEmail := currentUser.Email + if _, err := client.User.UpdateOneID(currentUser.ID). + SetEmail(email). + SetPasswordHash(hashedPassword). + Save(ctx); err != nil { + if dbent.IsConstraintError(err) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil { + if errors.Is(err, ErrEmailExists) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if applyFirstBindDefaults { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil { + return fmt.Errorf("apply email first bind defaults: %w", err) + } + } + + updatedUser, err := client.User.Get(ctx, currentUser.ID) + if err != nil { + return ErrServiceUnavailable + } + currentUser.Email = updatedUser.Email + currentUser.PasswordHash = updatedUser.PasswordHash + currentUser.Balance = updatedUser.Balance + currentUser.Concurrency = updatedUser.Concurrency + currentUser.UpdatedAt = updatedUser.UpdatedAt + return nil +} + +func replaceBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + oldEmail string, + newEmail string, + source string, +) error { + newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail) + if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil { + return err + } + + oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail) + if oldSubject == "" || oldSubject == newSubject { + return nil + } + + _, err := client.AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(oldSubject), + ). + Exec(ctx) + return err +} + +func ensureBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + subject string, + source string, +) error { + if client == nil || userID <= 0 || subject == "" { + return nil + } + + if strings.TrimSpace(source) == "" { + source = "auth_service_email_bind" + } + + if err := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(subject). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": strings.TrimSpace(source)}). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return err + } + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(subject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity.UserID != userID { + return ErrEmailExists + } + return nil +} + +func normalizeBoundEmailAuthIdentitySubject(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || isReservedEmail(normalized) { + return "" + } + return normalized +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..ea558ae2272e6bfd74c21ea858207880aa00198e --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -0,0 +1,383 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" +) + +func normalizeOAuthSignupSource(signupSource string) string { + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + return "email" + } + return signupSource +} + +// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth +// account-creation flows without relying on the public registration gate. +func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + return nil, ErrEmailVerifyRequired + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, ErrEmailVerifyRequired + } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + if s == nil || s.emailService == nil { + return nil, ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil { + return nil, err + } + return &SendVerifyCodeResult{ + Countdown: int(verifyCodeCooldown / time.Second), + }, nil +} + +func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil, nil + } + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { + return nil, ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" { + return nil, ErrInvitationCodeRequired + } + + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + return nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, ErrInvitationCodeInvalid + } + return redeemCode, nil +} + +// VerifyOAuthEmailCode verifies the locally entered email verification code for +// third-party signup and binding flows. This is intentionally independent from +// the global registration email verification toggle. +func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error { + email = strings.TrimSpace(strings.ToLower(email)) + verifyCode = strings.TrimSpace(verifyCode) + + if email == "" { + return ErrEmailVerifyRequired + } + if verifyCode == "" { + return ErrEmailVerifyRequired + } + if s == nil || s.emailService == nil { + return ErrServiceUnavailable + } + return s.emailService.VerifyCode(ctx, email, verifyCode) +} + +// RegisterOAuthEmailAccount creates a local account from a third-party first +// login after the user has verified a local email address. +func (s *AuthService) RegisterOAuthEmailAccount( + ctx context.Context, + email string, + password string, + verifyCode string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + return nil, nil, err + } + + if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil { + return nil, nil, err + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + signupSource = "email" + } + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + + user := &User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "") + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap +// only after the pending OAuth flow has fully reached its last reversible step. +func (s *AuthService) FinalizeOAuthEmailAccount( + ctx context.Context, + user *User, + invitationCode string, + signupSource string, +) error { + if s == nil || user == nil || user.ID <= 0 { + return ErrServiceUnavailable + } + + signupSource = normalizeOAuthSignupSource(signupSource) + invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + return err + } + if invitationRedeemCode != nil { + if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return ErrInvitationCodeInvalid + } + } + + s.updateOAuthSignupSource(ctx, user.ID, signupSource) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + return nil +} + +// RollbackOAuthEmailAccountCreation removes a partially-created local account +// and restores any invitation code already consumed by that account. +func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error { + if s == nil || s.userRepo == nil || userID <= 0 { + return ErrServiceUnavailable + } + if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil { + return err + } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete created oauth user: %w", err) + } + return nil +} + +func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil + } + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { + return ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" || userID <= 0 { + return nil + } + + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + if errors.Is(err, ErrRedeemCodeNotFound) { + return nil + } + return fmt.Errorf("load invitation code: %w", err) + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID { + return nil + } + + redeemCode.Status = StatusUnused + redeemCode.UsedBy = nil + redeemCode.UsedAt = nil + if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil { + return fmt.Errorf("restore invitation code: %w", err) + } + return nil +} + +func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client { + if s == nil || s.entClient == nil { + return nil + } + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return s.entClient +} + +func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if client := s.oauthEmailFlowClient(ctx); client != nil { + entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrRedeemCodeNotFound + } + return nil, err + } + return &RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: oauthEmailFlowStringValue(entity.Notes), + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil + } + return s.redeemRepo.GetByCode(ctx, invitationCode) +} + +func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error { + if client := s.oauthEmailFlowClient(ctx); client != nil { + affected, err := client.RedeemCode.Update(). + Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)). + SetStatus(StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return ErrRedeemCodeUsed + } + return nil + } + return s.redeemRepo.Use(ctx, invitationID, userID) +} + +func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + if client := s.oauthEmailFlowClient(ctx); client != nil { + update := client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). + SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err + } + return s.redeemRepo.Update(ctx, code) +} + +func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) { + client := s.oauthEmailFlowClient(ctx) + if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" { + return + } + _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx) +} + +func oauthEmailFlowStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} + +// ValidatePasswordCredentials checks the local password without completing the +// login flow. This is used by pending third-party account adoption flows before +// the external identity has been bound. +func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email))) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, ErrInvalidCredentials + } + return nil, ErrServiceUnavailable + } + if !user.IsActive() { + return nil, ErrUserNotActive + } + if !s.CheckPassword(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + return user, nil +} + +// RecordSuccessfulLogin updates last-login activity after a non-standard login +// flow finishes with a real session. +func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + if s != nil && s.userRepo != nil && userID > 0 { + user, err := s.userRepo.GetByID(ctx, userID) + if err == nil && user != nil && !isReservedEmail(user.Email) { + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) + } + } + s.touchUserLogin(ctx, userID) +} diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a77dda72e5a9e7dfd3482e8133197e4b789c115b --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -0,0 +1,251 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type redeemCodeRepoStub struct { + codesByCode map[string]*RedeemCode + useCalls []struct { + id int64 + userID int64 + } + updateCalls []*RedeemCode +} + +func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error { + panic("unexpected Create call") +} + +func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) { + if s.codesByCode == nil { + return nil, ErrRedeemCodeNotFound + } + redeemCode, ok := s.codesByCode[code] + if !ok { + return nil, ErrRedeemCodeNotFound + } + cloned := *redeemCode + return &cloned, nil +} + +func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + cloned := *code + s.updateCalls = append(s.updateCalls, &cloned) + if s.codesByCode == nil { + s.codesByCode = make(map[string]*RedeemCode) + } + s.codesByCode[cloned.Code] = &cloned + return nil +} + +func (s *redeemCodeRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error { + for code, redeemCode := range s.codesByCode { + if redeemCode.ID != id { + continue + } + now := time.Now().UTC() + redeemCode.Status = StatusUsed + redeemCode.UsedBy = &userID + redeemCode.UsedAt = &now + s.codesByCode[code] = redeemCode + s.useCalls = append(s.useCalls, struct { + id int64 + userID int64 + }{id: id, userID: userID}) + return nil + } + return ErrRedeemCodeNotFound +} + +func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +func newOAuthEmailFlowAuthService( + userRepo UserRepository, + redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, + settings map[string]string, + emailCache EmailCache, +) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache) + + return NewAuthService( + nil, + userRepo, + redeemRepo, + refreshTokenCache, + cfg, + settingService, + emailService, + nil, + nil, + nil, + nil, + ) +} + +func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUnused, + }, + }, + } + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + nil, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "INVITE123", + "oidc", + ) + + require.Nil(t, tokenPair) + require.Nil(t, user) + require.Error(t, err) + require.Contains(t, err.Error(), "generate token pair") + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, userRepo.created, 1) + require.Empty(t, redeemRepo.useCalls) + require.Empty(t, redeemRepo.updateCalls) +} + +func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { + userRepo := &userRepoStub{} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUsed, + UsedBy: func() *int64 { + v := int64(42) + return &v + }(), + UsedAt: func() *time.Time { + v := time.Now().UTC() + return &v + }(), + }, + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123") + + require.NoError(t, err) + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, redeemRepo.updateCalls, 1) + require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status) + require.Nil(t, redeemRepo.updateCalls[0].UsedBy) + require.Nil(t, redeemRepo.updateCalls[0].UsedAt) +} + +func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { + userRepo := &userRepoStub{deleteErr: errors.New("delete failed")} + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "") + + require.Error(t, err) + require.Contains(t, err.Error(), "delete created oauth user") +} diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go new file mode 100644 index 0000000000000000000000000000000000000000..aa06e59f3079a02ae7d9716bdc7c91029fe4d751 --- /dev/null +++ b/backend/internal/service/auth_oauth_first_bind.go @@ -0,0 +1,104 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + + entsql "entgo.io/ent/dialect/sql" +) + +// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap +// settings the first time a user binds a third-party identity. The grant is +// idempotent per user/provider pair. +func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 { + return nil + } + + if dbent.TxFromContext(ctx) != nil { + return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin first bind defaults transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil { + return err + } + return tx.Commit() +} + +func (s *AuthService) applyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true) + if err != nil { + return fmt.Errorf("load auth source defaults: %w", err) + } + if !enabled { + return nil + } + + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() + } + + var result entsql.Result + if err := client.Driver().Exec( + ctx, + `INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + []any{userID, strings.TrimSpace(providerType), "first_bind"}, + &result, + ); err != nil { + return fmt.Errorf("record first bind provider grant: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read first bind provider grant result: %w", err) + } + if affected == 0 { + return nil + } + + if providerDefaults.Balance != 0 { + if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind balance default: %w", err) + } + } + if providerDefaults.Concurrency != 0 { + if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind concurrency default: %w", err) + } + } + if s.defaultSubAssigner != nil { + for _, item := range providerDefaults.Subscriptions { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by first bind defaults", + }); err != nil { + return fmt.Errorf("apply first bind subscription default: %w", err) + } + } + } + + return nil +} diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go new file mode 100644 index 0000000000000000000000000000000000000000..7001ee18715458cc500a4afef42a6e9d14c60ab5 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service.go @@ -0,0 +1,347 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + + entsql "entgo.io/ent/dialect/sql" +) + +var ( + ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") + ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") + ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") + ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") + ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") + ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") + ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") +) + +const ( + defaultPendingAuthTTL = 15 * time.Minute + defaultPendingAuthCompletionTTL = 5 * time.Minute +) + +type PendingAuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type CreatePendingAuthSessionInput struct { + SessionToken string + Intent string + Identity PendingAuthIdentityKey + TargetUserID *int64 + RedirectTo string + ResolvedEmail string + RegistrationPasswordHash string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + LocalFlowState map[string]any + ExpiresAt time.Time +} + +type IssuePendingAuthCompletionCodeInput struct { + PendingAuthSessionID int64 + BrowserSessionKey string + TTL time.Duration +} + +type IssuePendingAuthCompletionCodeResult struct { + Code string + ExpiresAt time.Time +} + +type PendingIdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type AuthPendingIdentityService struct { + entClient *dbent.Client +} + +func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { + return &AuthPendingIdentityService{entClient: entClient} +} + +func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken := strings.TrimSpace(input.SessionToken) + if sessionToken == "" { + var err error + sessionToken, err = randomOpaqueToken(24) + if err != nil { + return nil, err + } + } + + expiresAt := input.ExpiresAt.UTC() + if expiresAt.IsZero() { + expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) + } + + create := s.entClient.PendingAuthSession.Create(). + SetSessionToken(sessionToken). + SetIntent(strings.TrimSpace(input.Intent)). + SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). + SetRedirectTo(strings.TrimSpace(input.RedirectTo)). + SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). + SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). + SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). + SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). + SetLocalFlowState(copyPendingMap(input.LocalFlowState)). + SetExpiresAt(expiresAt) + if input.TargetUserID != nil { + create = create.SetTargetUserID(*input.TargetUserID) + } + return create.Save(ctx) +} + +func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + + code, err := randomOpaqueToken(24) + if err != nil { + return nil, err + } + ttl := input.TTL + if ttl <= 0 { + ttl = defaultPendingAuthCompletionTTL + } + expiresAt := time.Now().UTC().Add(ttl) + + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeHash(hashPendingAuthCode(code)). + SetCompletionCodeExpiresAt(expiresAt) + if strings.TrimSpace(input.BrowserSessionKey) != "" { + update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) + } + if _, err := update.Save(ctx); err != nil { + return nil, err + } + + return &IssuePendingAuthCompletionCodeResult{ + Code: code, + ExpiresAt: expiresAt, + }, nil +} + +func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthCodeInvalid + } + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) +} + +func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) +} + +func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken = strings.TrimSpace(sessionToken) + if sessionToken == "" { + return nil, ErrPendingAuthSessionNotFound + } + + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) consumeSession( + ctx context.Context, + session *dbent.PendingAuthSession, + browserSessionKey string, + expiredErr error, + consumedErr error, +) (*dbent.PendingAuthSession, error) { + if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + + now := time.Now().UTC() + updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx) + if err != nil { + return nil, err + } + return updated, nil +} + +func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { + if session == nil { + return ErrPendingAuthSessionNotFound + } + + now := time.Now().UTC() + if session.ConsumedAt != nil { + return consumedErr + } + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + return expiredErr + } + if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { + return expiredErr + } + if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return ErrPendingAuthBrowserMismatch + } + return nil +} + +func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := s.entClient.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(ctx); err != nil { + return nil, err + } + } + + existing, err := s.entClient.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if existing == nil { + create := s.entClient.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func copyPendingMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func randomOpaqueToken(byteLen int) (string, error) { + if byteLen <= 0 { + byteLen = 16 + } + buf := make([]byte, byteLen) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func hashPendingAuthCode(code string) string { + sum := sha256.Sum256([]byte(code)) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..de0b18d28b9867887db4dc5b07928bbb07b23b36 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -0,0 +1,358 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return NewAuthPendingIdentityService(client), client +} + +func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("pending-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + TargetUserID: &targetUser.ID, + RedirectTo: "/profile", + ResolvedEmail: "user@example.com", + BrowserSessionKey: "browser-1", + UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"}, + LocalFlowState: map[string]any{"step": "email_required"}, + }) + require.NoError(t, err) + require.NotEmpty(t, session.SessionToken) + require.Equal(t, "bind_current_user", session.Intent) + require.Equal(t, "wechat", session.ProviderType) + require.NotNil(t, session.TargetUserID) + require.Equal(t, targetUser.ID, *session.TargetUserID) + require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"]) + require.Equal(t, "email_required", session.LocalFlowState["step"]) +} + +func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expected", + UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"}, + LocalFlowState: map[string]any{"step": "pending"}, + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expected", + }) + require.NoError(t, err) + require.NotEmpty(t, issued.Code) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + require.Empty(t, consumed.CompletionCodeHash) + require.Nil(t, consumed.CompletionCodeExpiresAt) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.ErrorIs(t, err, ErrPendingAuthCodeInvalid) +} + +func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expired", + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expired", + TTL: time.Second, + }) + require.NoError(t, err) + + _, err = client.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired") + require.ErrorIs(t, err, ErrPendingAuthCodeExpired) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + require.NoError(t, err) + + first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.True(t, first.AdoptDisplayName) + require.False(t, first.AdoptAvatar) + require.Nil(t, first.IdentityID) + + second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + require.NotNil(t, second.IdentityID) + require.Equal(t, identity.ID, *second.IdentityID) + require.True(t, second.AdoptAvatar) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption-reassign@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-reassign"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-reassign", + }, + }) + require.NoError(t, err) + + firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: firstSession.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.NotNil(t, firstDecision.IdentityID) + require.Equal(t, identity.ID, *firstDecision.IdentityID) + + secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-reassign", + }, + }) + require.NoError(t, err) + + secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: secondSession.ID, + IdentityID: &identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.NotNil(t, secondDecision.IdentityID) + require.Equal(t, identity.ID, *secondDecision.IdentityID) + + reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedFirst.IdentityID) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) { + t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL") + + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("legacy-null-session@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("legacy-null-session"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + _, err = client.ExecContext( + ctx, + `INSERT INTO identity_adoption_decisions + (identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id) + VALUES (?, ?, ?, ?, ?, ?, NULL)`, + identity.ID, + true, + false, + time.Now().UTC(), + time.Now().UTC(), + time.Now().UTC(), + ) + require.NoError(t, err) + legacyDecision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.IdentityIDEQ(identity.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, legacyDecision.IdentityID) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "legacy-null-session", + }, + }) + require.NoError(t, err) + + decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + + reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedLegacy.IdentityID) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "subject-session-token", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "token", + }, + }, + }) + require.NoError(t, err) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fd28cd4235274da6d0433c690454a1089214bf4a..6d61894b2acb29c0b90605b24b16f09ab71f96bd 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -13,6 +13,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -77,6 +78,12 @@ type DefaultSubscriptionAssigner interface { AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } +type signupGrantPlan struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting +} + // NewAuthService 创建认证服务实例 func NewAuthService( entClient *dbent.Client, @@ -106,6 +113,13 @@ func NewAuthService( } } +func (s *AuthService) EntClient() *dbent.Client { + if s == nil { + return nil + } + return s.entClient +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") @@ -179,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, fmt.Errorf("hash password: %w", err) } - // 获取默认配置 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + grantPlan := s.resolveSignupGrantPlan(ctx, "email") // 创建用户 user := &User{ Email: email, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -205,7 +213,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, "email", true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -469,21 +478,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return "", nil, fmt.Errorf("hash password: %w", err) } - // 新用户默认值。 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -501,7 +505,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -520,7 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) @@ -584,20 +588,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -630,7 +630,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -646,7 +647,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -670,7 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { return nil, nil, fmt.Errorf("generate token pair: %w", err) @@ -678,77 +679,270 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. -const pendingOAuthTokenTTL = 10 * time.Minute +func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: notes, + }); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} -// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. -const pendingOAuthPurpose = "pending_oauth_registration" +func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan { + plan := signupGrantPlan{} + if s != nil && s.cfg != nil { + plan.Balance = s.cfg.Default.UserBalance + plan.Concurrency = s.cfg.Default.UserConcurrency + } + if s == nil || s.settingService == nil { + return plan + } -type pendingOAuthClaims struct { - Email string `json:"email"` - Username string `json:"username"` - Purpose string `json:"purpose"` - jwt.RegisteredClaims + plan.Balance = s.settingService.GetDefaultBalance(ctx) + plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) + plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + + resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) + return plan + } + if !enabled { + return plan + } + + plan.Balance = resolved.Balance + plan.Concurrency = resolved.Concurrency + plan.Subscriptions = resolved.Subscriptions + return plan } -// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity -// while waiting for the user to supply an invitation code. -func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { - now := time.Now() - claims := &pendingOAuthClaims{ - Email: email, - Username: username, - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, +func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) { + if defaults == nil { + return ProviderDefaultGrantSettings{}, false + } + + switch strings.ToLower(strings.TrimSpace(signupSource)) { + case "email": + return defaults.Email, true + case "linuxdo": + return defaults.LinuxDo, true + case "oidc": + return defaults.OIDC, true + case "wechat": + return defaults.WeChat, true + default: + return ProviderDefaultGrantSettings{}, false } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.Secret)) } -// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. -// Returns ErrInvalidToken when the token is invalid or expired. -func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { - if len(tokenStr) > maxTokenLength { - return "", "", ErrInvalidToken +func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { + if user == nil || user.ID <= 0 { + return } - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return []byte(s.cfg.JWT.Secret), nil - }) - if parseErr != nil { - return "", "", ErrInvalidToken + + if strings.TrimSpace(signupSource) == "" { + signupSource = "email" + } + s.updateUserSignupSource(ctx, user.ID, signupSource) + + if touchLogin { + s.touchUserLogin(ctx, user.ID) + } +} + +func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) { + if s == nil || s.entClient == nil || userID <= 0 { + return } - claims, ok := token.Claims.(*pendingOAuthClaims) - if !ok || !token.Valid { - return "", "", ErrInvalidToken + if strings.TrimSpace(signupSource) == "" { + return } - if claims.Purpose != pendingOAuthPurpose { - return "", "", ErrInvalidToken + if err := s.entClient.User.UpdateOneID(userID). + SetSignupSource(signupSource). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err) } - return claims.Email, claims.Username, nil } -func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { - if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { +func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { + if s == nil || s.entClient == nil || userID <= 0 { return } - items := s.settingService.GetDefaultSubscriptions(ctx) - for _, item := range items { - if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ - UserID: userID, - GroupID: item.GroupID, - ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions setting", - }); err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + now := time.Now().UTC() + if err := s.entClient.User.UpdateOneID(userID). + SetLastLoginAt(now). + SetLastActiveAt(now). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err) + } +} + +func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) { + if s == nil || user == nil || user.ID <= 0 { + return + } + identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill") + if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err) + } + } +} + +func (s *AuthService) shouldApplyEmailFirstBindDefaults( + ctx context.Context, + userID int64, + identity *dbent.AuthIdentity, + created bool, +) bool { + source := emailAuthIdentitySource(identity.Metadata) + if source == "auth_service_login_backfill" { + return false + } + if created { + return true + } + if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID { + return false + } + if source != "auth_service_dual_write" { + return false + } + + hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind") + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err) + return false + } + return !hasGrant +} + +func emailAuthIdentitySource(metadata map[string]any) string { + if len(metadata) == 0 { + return "" + } + raw, ok := metadata["source"] + if !ok { + return "" + } + return strings.TrimSpace(fmt.Sprint(raw)) +} + +func (s *AuthService) hasProviderGrantRecord( + ctx context.Context, + userID int64, + providerType string, + grantReason string, +) (bool, error) { + if s == nil || s.entClient == nil || userID <= 0 { + return false, nil + } + + rows, err := s.entClient.QueryContext( + ctx, + `SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`, + userID, + strings.TrimSpace(providerType), + strings.TrimSpace(grantReason), + ) + if err != nil { + return false, err + } + defer func() { _ = rows.Close() }() + return rows.Next(), rows.Err() +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return nil, false + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return nil, false + } + if strings.TrimSpace(source) == "" { + source = "auth_service_dual_write" + } + + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() + } + + buildQuery := func() *dbent.AuthIdentityQuery { + return client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(email), + ) + } + + existed, err := buildQuery().Exist(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + + if !existed { + if err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": strings.TrimSpace(source), + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + if isSQLNoRowsError(err) { + return nil, false + } } + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + } + + identity, err := buildQuery().Only(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + if identity.UserID != user.ID { + logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID) + return nil, false + } + + return identity, !existed +} + +func inferLegacySignupSource(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + switch { + case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): + return "linuxdo" + case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): + return "oidc" + case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain): + return "wechat" + default: + return "email" } } @@ -834,7 +1028,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d32a4a40d6af784646790676edf622231762a42e --- /dev/null +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -0,0 +1,529 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type emailBindDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + +type flakyEmailBindDefaultSubAssignerStub struct { + err error + calls []*service.AssignSubscriptionInput +} + +func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return nil, false, s.err +} + +func newAuthServiceForEmailBind( + t *testing.T, + settings map[string]string, + emailCache service.EmailCache, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-bind-email-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingRepo := &emailBindSettingRepoStub{values: settings} + settingSvc := service.NewSettingService(settingRepo, cfg) + + var emailSvc *service.EmailService + if emailCache != nil { + emailSvc = service.NewEmailService(settingRepo, emailCache) + } + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + return svc, repo, client +} + +func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain). + SetUsername("legacy-user"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "newemail@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "newemail@example.com", storedUser.Email) + require.Equal(t, 11.0, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("newemail@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, user.ID, assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + sourceUser, err := client.User.Create(). + SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.User.Create(). + SetEmail("taken@example.com"). + SetUsername("taken-user"). + SetPasswordHash("hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailExists) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, sourceUser.ID) + require.NoError(t, err) + require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email) + require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) { + assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain + user, err := client.User.Create(). + SetEmail(originalEmail). + SetUsername("legacy-rollback"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password") + require.ErrorContains(t, err, "apply email first bind defaults") + require.ErrorContains(t, err, "temporary assign failure") + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, originalEmail, storedUser.Email) + require.Equal(t, "old-hash", storedUser.PasswordHash) + require.Equal(t, 2.5, storedUser.Balance) + require.Equal(t, 1, storedUser.Concurrency) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("rollback@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("source-user@example.com"). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailReserved) + require.Nil(t, updatedUser) +} + +func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(7.5). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). + Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "new@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "new@example.com", storedUser.Email) + require.Equal(t, 7.5, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, newIdentityCount) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, oldIdentityCount) + + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). + Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password") + require.ErrorIs(t, err, service.ErrPasswordIncorrect) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "current@example.com", storedUser.Email) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, oldIdentityCount) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, newIdentityCount) +} + +type emailBindSettingRepoStub struct { + values map[string]string +} + +func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil +} + +func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *emailBindSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +type emailBindCacheStub struct { + data *service.VerificationCodeData + err error +} + +func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + if s.err != nil { + return nil, s.err + } + return s.data, nil +} + +func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2233e427eb7f479d68a11ef424a148b6575069be --- /dev/null +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -0,0 +1,482 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type authIdentityDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + +type flakyAuthIdentityDefaultSubAssignerStub struct { + failuresRemaining int + calls []*service.AssignSubscriptionInput +} + +func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + if s.failuresRemaining > 0 { + s.failuresRemaining-- + return nil, false, errors.New("temporary assign failure") + } + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + +type authIdentitySettingRepoStub struct { + values map[string]string +} + +func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil +} + +func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func newAuthServiceWithEnt( + t *testing.T, + settings map[string]string, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-auth-identity-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ + values: settings, + }, cfg) + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner) + return svc, repo, client +} + +func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + token, user, err := svc.Register(ctx, "user@example.com", "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, user) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "email", storedUser.SignupSource) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("user@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.NotNil(t, identity.VerifiedAt) +} + +func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("login@example.com"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetBalance(1). + SetConcurrency(1). + Save(ctx) + require.NoError(t, err) + + old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) + _, err = client.User.UpdateOneID(user.ID). + SetLastLoginAt(old). + SetLastActiveAt(old). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + require.True(t, storedUser.LastLoginAt.Equal(old)) + require.True(t, storedUser.LastActiveAt.Equal(old)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + user := &service.User{ + Email: "record@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("record@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("legacy@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`, + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "5", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("merged-first-bind@example.com"). + SetUsername("merged-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("bound@example.com"). + SetUsername("bound-user"). + SetPasswordHash(passwordHash). + SetBalance(2). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("bound@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "preexisting"}). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 2.0, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) { + assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("retry-first-bind@example.com"). + SetUsername("retry-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var count int + rows, err := client.QueryContext( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, + userID, + providerType, + grantReason, + ) + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&count)) + require.NoError(t, rows.Err()) + return count +} diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go deleted file mode 100644 index 0472e06c72d7809d55d3f227bbd0bafc778880cd..0000000000000000000000000000000000000000 --- a/backend/internal/service/auth_service_pending_oauth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build unit - -package service - -import ( - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" -) - -func newAuthServiceForPendingOAuthTest() *AuthService { - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret-pending-oauth", - ExpireHour: 1, - }, - } - return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) -} - -// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 -func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - require.NotEmpty(t, token) - - email, username, err := svc.VerifyPendingOAuthToken(token) - require.NoError(t, err) - require.Equal(t, "user@example.com", email) - require.Equal(t, "alice", username) -} - -// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - // 签发一个普通 access token(JWTClaims,无 Purpose 字段) - accessToken, err := svc.GenerateToken(&User{ - ID: 1, - Email: "user@example.com", - Role: RoleUser, - }) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(accessToken) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "some_other_purpose", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "", // 旧 token 无此字段,反序列化后为零值 - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - past := time.Now().Add(-1 * time.Hour) - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(past), - IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { - other := NewAuthService(nil, nil, nil, nil, &config.Config{ - JWT: config.JWTConfig{Secret: "other-secret"}, - }, nil, nil, nil, nil, nil, nil) - - token, err := other.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - - svc := newAuthServiceForPendingOAuthTest() - _, _, err = svc.VerifyPendingOAuthToken(token) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - giant := make([]byte, maxTokenLength+1) - for i := range giant { - giant[i] = 'a' - } - _, _, err := svc.VerifyPendingOAuthToken(string(giant)) - require.ErrorIs(t, err, ErrInvalidToken) -} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 103bafe709943160f2d0aefc181a88c1744149d9..dbd18a20a979c4b968669d3bf76b72daab9fe94d 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error { } func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") + if s.err != nil { + return nil, s.err + } + result := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + result[key] = v + } + } + return result, nil } func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { @@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct { err error } +type refreshTokenCacheStub struct{} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) { + return nil, ErrRefreshTokenNotFound +} + +func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -322,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") @@ -469,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) service.defaultSubAssigner = assigner @@ -484,3 +537,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, 7, assigner.calls[1].ValidityDays) } + +func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) { + repo := &userRepoStub{nextID: 52} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`, + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 12.5, user.Balance) + require.Equal(t, 7, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 53} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "99", + SettingKeyAuthSourceDefaultEmailConcurrency: "88", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-global@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 3.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 54} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "9.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-merged@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 9.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { + repo := &userRepoStub{nextID: 61} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`, + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Equal(t, int64(61), user.ID) + require.Equal(t, 21.75, user.Balance) + require.Equal(t, 9, user.Concurrency) + require.Len(t, repo.created, 1) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(22), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) { + existing := &User{ + ID: 88, + Email: "linuxdo-123@linuxdo-connect.invalid", + Username: "existing-linuxdo", + Role: RoleUser, + Status: StatusActive, + Balance: 4, + Concurrency: 1, + TokenVersion: 2, + } + repo := &userRepoStub{user: existing} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.Equal(t, existing.ID, user.ID) + require.Equal(t, 4.0, user.Balance) + require.Equal(t, 1, user.Concurrency) + require.Empty(t, repo.created) + require.Empty(t, assigner.calls) +} diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 4a8b8f03e570c67319b946f82a6ba3f2700c52e6..0eaf4570bbf4b3c16937dae4f2f2ce027bf58d34 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -86,6 +86,14 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, return &User{ID: id, Balance: s.balance}, nil } +func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + return nil +} + func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { cache := &billingCacheMissStub{} userRepo := &balanceLoadUserRepoStub{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cb452efbeec05e037e188c3eea0a0c0c55a7313e..3c6888b8fe7ae7d4936dd2bc5589cd133be09acf 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" +// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" + // Setting keys const ( // 注册设置 @@ -108,6 +111,24 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // WeChat Connect OAuth 登录设置 + SettingKeyWeChatConnectEnabled = "wechat_connect_enabled" + SettingKeyWeChatConnectAppID = "wechat_connect_app_id" + SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret" + SettingKeyWeChatConnectOpenAppID = "wechat_connect_open_app_id" + SettingKeyWeChatConnectOpenAppSecret = "wechat_connect_open_app_secret" + SettingKeyWeChatConnectMPAppID = "wechat_connect_mp_app_id" + SettingKeyWeChatConnectMPAppSecret = "wechat_connect_mp_app_secret" + SettingKeyWeChatConnectMobileAppID = "wechat_connect_mobile_app_id" + SettingKeyWeChatConnectMobileAppSecret = "wechat_connect_mobile_app_secret" + SettingKeyWeChatConnectOpenEnabled = "wechat_connect_open_enabled" + SettingKeyWeChatConnectMPEnabled = "wechat_connect_mp_enabled" + SettingKeyWeChatConnectMobileEnabled = "wechat_connect_mobile_enabled" + SettingKeyWeChatConnectMode = "wechat_connect_mode" + SettingKeyWeChatConnectScopes = "wechat_connect_scopes" + SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url" + SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url" + // Generic OIDC OAuth 登录设置 SettingKeyOIDCConnectEnabled = "oidc_connect_enabled" SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name" @@ -153,6 +174,29 @@ const ( SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + // 第三方认证来源默认授予配置 + SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" + // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354a1eb3b926dafa15ecd3cc1e4a37c11cc..5fda3abddec3bf82b9fcb03e6b8861b147c808e6 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -13,14 +13,30 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/sync/singleflight" ) const ( openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerLoadBalance = "load_balance" + openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled" +) + +const ( + openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second + openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second ) +type cachedOpenAIAdvancedSchedulerSetting struct { + enabled bool + expiresAt int64 +} + +var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting +var openAIAdvancedSchedulerSettingSF singleflight.Group + type OpenAIAccountScheduleRequest struct { GroupID *int64 SessionHash string @@ -751,14 +767,13 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { - // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { return true } - if s == nil || s.service == nil || account == nil { + if s == nil || s.service == nil { return false } - return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport + return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) } func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { @@ -805,10 +820,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler return snapshot } -func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { +func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository { + if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil { + return nil + } + return s.rateLimitService.settingService.settingRepo +} + +func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled + } + } + + result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled, nil + } + } + + enabled := false + if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil { + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout) + defer cancel() + + value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey) + if err == nil { + enabled = strings.EqualFold(strings.TrimSpace(value), "true") + } + } + + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: enabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + return enabled, nil + }) + + enabled, _ := result.(bool) + return enabled +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler { if s == nil { return nil } + if !s.isOpenAIAdvancedSchedulerEnabled(ctx) { + return nil + } s.openaiSchedulerOnce.Do(func() { if s.openaiAccountStats == nil { s.openaiAccountStats = newOpenAIAccountRuntimeStats() @@ -820,6 +881,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule return s.openaiScheduler } +func resetOpenAIAdvancedSchedulerSettingCacheForTest() { + openAIAdvancedSchedulerSettingCache = atomic.Value{} + openAIAdvancedSchedulerSettingSF = singleflight.Group{} +} + func (s *OpenAIGatewayService) SelectAccountWithScheduler( ctx context.Context, groupID *int64, @@ -830,11 +896,37 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance - return selection, decision, err + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + return selection, decision, err + } + + effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) + for { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + if err != nil { + return nil, decision, err + } + if selection == nil || selection.Account == nil { + return selection, decision, nil + } + if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) { + return selection, decision, nil + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if effectiveExcludedIDs == nil { + effectiveExcludedIDs = make(map[int64]struct{}) + } + if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists { + return nil, decision, ErrNoAvailableAccounts + } + effectiveExcludedIDs[selection.Account.ID] = struct{}{} + } } var stickyAccountID int64 @@ -855,8 +947,29 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( }) } +func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} { + if len(excludedIDs) == 0 { + return nil + } + cloned := make(map[int64]struct{}, len(excludedIDs)) + for id := range excludedIDs { + cloned[id] = struct{}{} + } + return cloned +} + +func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || account == nil { + return false + } + return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -864,7 +977,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -872,7 +985,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { } func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return OpenAIAccountSchedulerMetricsSnapshot{} } diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed40ae239a0dc4c4dc0dd56ad6f61ce11d..b02370cb5ffd28c9caf9c57bf4c66e1559b4518a 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "math" "sync" @@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { accountsByID map[int64]*Account } +type schedulerTestOpenAIAccountRepo struct { + AccountRepository + accounts []Account +} + +func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + +type schedulerTestConcurrencyCache struct { + ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool +} + +func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } + return true, nil +} + +func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} + +func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } + out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } + for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } + out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return out, nil +} + +func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type schedulerTestGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + +func newSchedulerTestOpenAIWSV2Config() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} + +type openAIAdvancedSchedulerSettingRepoStub struct { + values map[string]string +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s == nil || s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected call to Set") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected call to GetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected call to SetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected call to GetAll") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected call to Delete") +} + +func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + repo := &openAIAdvancedSchedulerSettingRepoStub{ + values: map[string]string{}, + } + if enabled != "" { + repo.values[openAIAdvancedSchedulerSettingKey] = enabled + } + return &RateLimitService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { if len(s.snapshotAccounts) == 0 { return nil, false, nil @@ -45,6 +242,230 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 return &cloned, nil } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10106) + accounts := []Account{ + { + ID: 36001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + }, + { + ID: 36002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour)) + require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_disabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10108) + accounts := []Account{ + { + ID: 36011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 36012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36012), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10109) + accounts := []Account{ + { + ID: 36021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.ErrorContains(t, err, "no available OpenAI accounts") + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10107) + accounts := []Account{ + { + ID: 37001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour)) + require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_enabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { ctx := context.Background() groupID := int64(10101) @@ -53,10 +474,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, + cache: cache, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) @@ -76,7 +504,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + } account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) require.NoError(t, err) @@ -92,18 +525,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} snapshotCache := &openAISnapshotCacheStub{ snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, cache: cache, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -128,8 +562,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, } @@ -153,7 +588,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "openai_apikey_responses_websockets_v2_enabled": true, }, } - cache := &stubGatewayCache{} + cache := &schedulerTestGatewayCache{} cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true @@ -163,10 +598,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: cfg, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } store := svc.getOpenAIWSStateStore() @@ -204,17 +640,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_abc": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -260,7 +697,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS Priority: 9, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_sticky_busy": 21001, }, @@ -273,7 +710,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ acquireResults: map[int64]bool{ 21001: false, // sticky 账号已满 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) @@ -288,9 +725,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -328,17 +766,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "openai_ws_force_http": true, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_force_http": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -387,15 +826,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick }, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_ws_only": 2201, }, } - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, @@ -403,9 +842,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -445,10 +885,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, - cfg: newOpenAIWSV2TestConfig(), - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: newSchedulerTestOpenAIWSV2Config(), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -507,7 +948,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, @@ -520,9 +961,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -559,16 +1001,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_metrics": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -749,7 +1192,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, @@ -757,9 +1200,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -905,12 +1349,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { } func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + svc := &OpenAIGatewayService{} ttft := 120 svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() - require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) @@ -947,7 +1393,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() scheduler.service = &OpenAIGatewayService{cfg: cfg} account := &Account{ ID: 8801, diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index c5de8203412ead756d693f12819f3c0d53b4a9ed..ddafc6eb76d5dbb9afcc9bf8ad2c74b3d474211a 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index ac7d28a7f0ddbd33642bdbac610fc00a03249337..663066a35d9ce7381b841180a3408fef390c013b 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( responsesBody = stripped } } + responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody) + if err != nil { + return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err) + } // Minimal stub populated from the raw body so downstream billing // propagation (ServiceTier, ReasoningEffort) keeps working. responsesReq = &apicompat.ResponsesRequest{ Model: upstreamModel, - ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(), + ServiceTier: normalizedServiceTier, } if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" { responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort} @@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("convert chat completions to responses: %w", err) } responsesReq.Model = upstreamModel + normalizeResponsesRequestServiceTier(responsesReq) responsesBody, err = json.Marshal(responsesReq) if err != nil { return nil, fmt.Errorf("marshal responses request: %w", err) @@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return result, handleErr } +func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) { + if req == nil { + return + } + req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier) +} + +func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) { + if len(body) == 0 { + return body, "", nil + } + rawServiceTier := gjson.GetBytes(body, "service_tier").String() + if rawServiceTier == "" { + return body, "", nil + } + normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier) + if normalizedServiceTier == "" { + trimmed, err := sjson.DeleteBytes(body, "service_tier") + return trimmed, "", err + } + if normalizedServiceTier == rawServiceTier { + return body, normalizedServiceTier, nil + } + trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier) + return trimmed, normalizedServiceTier, err +} + +func normalizedOpenAIServiceTierValue(raw string) string { + normalized := normalizeOpenAIServiceTier(raw) + if normalized == nil { + return "" + } + return *normalized +} + // handleChatCompletionsErrorResponse reads an upstream error and returns it in // OpenAI Chat Completions error format. func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a00fb71cab773af64459759cf246cd0a7ce841b5 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -0,0 +1,44 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNormalizeResponsesRequestServiceTier(t *testing.T) { + t.Parallel() + + req := &apicompat.ResponsesRequest{ServiceTier: " fast "} + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "priority", req.ServiceTier) + + req.ServiceTier = "flex" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "flex", req.ServiceTier) + + req.ServiceTier = "default" + normalizeResponsesRequestServiceTier(req) + require.Empty(t, req.ServiceTier) +} + +func TestNormalizeResponsesBodyServiceTier(t *testing.T) { + t.Parallel() + + body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`)) + require.NoError(t, err) + require.Equal(t, "priority", tier) + require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`)) + require.NoError(t, err) + require.Equal(t, "flex", tier) + require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`)) + require.NoError(t, err) + require.Empty(t, tier) + require.False(t, gjson.GetBytes(body, "service_tier").Exists()) +} diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index 569052788564c04c21fa164299e99c20832edb7b..57a4108f4b801a360a2128ecddc0bfcfe0f1d533 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -20,6 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return nil, fmt.Errorf("query provider instances: %w", err) } typeInstances := pcGroupByPaymentType(instances) + typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances) resp := &MethodLimitsResponse{ Methods: make(map[string]MethodLimits, len(typeInstances)), } @@ -31,6 +32,27 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return resp, nil } +func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { + if len(typeInstances) == 0 { + return typeInstances + } + + filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances)) + for paymentType, groupedInstances := range typeInstances { + filtered[paymentType] = groupedInstances + } + + for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { + matching := filterEnabledVisibleMethodInstances(instances, method) + if len(matching) != 1 { + delete(filtered, method) + continue + } + filtered[method] = []*dbent.PaymentProviderInstance{matching[0]} + } + return filtered +} + // GetMethodLimits returns per-payment-type limits from enabled provider instances. func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go index 73ad66ef03c60127040d391e7dfb091ebced8dc9..b392558381b65beaa4ae0bf3fd99c1e0ef2c1296 100644 --- a/backend/internal/service/payment_config_limits_test.go +++ b/backend/internal/service/payment_config_limits_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "testing" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -299,3 +300,66 @@ func TestPcInstanceTypeLimits(t *testing.T) { } }) } + +func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + } + + resp, err := svc.GetAvailableMethodLimits(ctx) + if err != nil { + t.Fatalf("GetAvailableMethodLimits returned error: %v", err) + } + + if _, ok := resp.Methods[payment.TypeAlipay]; ok { + t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay]) + } + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + if !ok { + t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) + } + if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { + t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) + } + if resp.GlobalMin != 30 || resp.GlobalMax != 300 { + t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax) + } +} diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 30ff4253f5e625f93f878e12b02a579911b4dfe9..d2f89b06db7338a048d0ea3ef79d29c3ace8416c 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -150,6 +150,9 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil { return nil, err } + if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil { + return nil, err + } if req.Enabled { if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil { return nil, err @@ -183,26 +186,25 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { // NOTE: This function exceeds 30 lines due to per-field nil-check patch update // boilerplate and pending-order safety checks. func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { - var cachedInst *dbent.PaymentProviderInstance - loadInst := func() (*dbent.PaymentProviderInstance, error) { - if cachedInst != nil { - return cachedInst, nil - } - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err != nil { - return nil, fmt.Errorf("load provider instance: %w", err) - } - cachedInst = inst - return inst, nil + current, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("load provider instance: %w", err) + } + nextEnabled := current.Enabled + if req.Enabled != nil { + nextEnabled = *req.Enabled + } + nextSupportedTypes := current.SupportedTypes + if req.SupportedTypes != nil { + nextSupportedTypes = joinTypes(req.SupportedTypes) + } + if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil { + return nil, err } if req.Config != nil { - inst, err := loadInst() - if err != nil { - return nil, err - } hasSensitive := false for k, v := range req.Config { - if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { + if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) { hasSensitive = true break } @@ -231,11 +233,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in // Validate merged config when the instance will end up enabled. // This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time, // so admins see them in the dialog instead of only when an order is created. - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err != nil { - return nil, fmt.Errorf("load provider instance: %w", err) - } - finalEnabled := inst.Enabled + finalEnabled := current.Enabled if req.Enabled != nil { finalEnabled = *req.Enabled } @@ -249,12 +247,12 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if finalEnabled { configToValidate := mergedConfig if configToValidate == nil { - configToValidate, err = s.decryptConfig(inst.Config) + configToValidate, err = s.decryptConfig(current.Config) if err != nil { return nil, fmt.Errorf("decrypt existing config: %w", err) } } - if err := s.validateProviderConfig(inst.ProviderKey, configToValidate); err != nil { + if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil { return nil, err } } @@ -277,11 +275,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if count > 0 { // Load current instance to compare types - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err != nil { - return nil, fmt.Errorf("load provider instance: %w", err) - } - oldTypes := strings.Split(inst.SupportedTypes, ",") + oldTypes := strings.Split(current.SupportedTypes, ",") newTypes := req.SupportedTypes for _, ot := range oldTypes { ot = strings.TrimSpace(ot) @@ -326,10 +320,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if req.RefundEnabled != nil { refundEnabled = *req.RefundEnabled } else { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil { - refundEnabled = inst.RefundEnabled - } + refundEnabled = current.RefundEnabled } if refundEnabled { u.SetAllowUserRefund(true) diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index bc2a9b18be49c735b705cc4a3c20c5167f2f3efc..2c0f8206e08f9065b1367bea3f76d07ab06449ea 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -3,8 +3,10 @@ package service import ( + "context" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -196,3 +198,122 @@ func TestJoinTypes(t *testing.T) { }) } } + +func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + _, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay Alipay", + Config: map[string]string{ + "pid": "1001", + "pkey": "pkey-1001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"alipay"}, + Enabled: true, + }) + require.NoError(t, err) + + _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "alipay", + Name: "Official Alipay", + Config: map[string]string{"appId": "app-1"}, + SupportedTypes: []string{"alipay"}, + Enabled: true, + }) + require.Error(t, err) + require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) +} + +func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay WeChat", + Config: map[string]string{ + "pid": "2001", + "pkey": "pkey-2001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"wxpay"}, + Enabled: true, + }) + require.NoError(t, err) + require.NotNil(t, existing) + + candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "wxpay", + Name: "Official WeChat", + Config: map[string]string{"appId": "wx-app"}, + SupportedTypes: []string{"wxpay"}, + Enabled: false, + }) + require.NoError(t, err) + + _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ + Enabled: boolPtrValue(true), + }) + require.Error(t, err) + require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) +} + +func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay", + Config: map[string]string{ + "pid": "3001", + "pkey": "pkey-3001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"alipay"}, + Enabled: false, + }) + require.NoError(t, err) + + _, err = svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Enabled: boolPtrValue(true), + SupportedTypes: []string{"alipay", "wxpay"}, + }) + require.NoError(t, err) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + require.True(t, saved.Enabled) + require.Equal(t, "alipay,wxpay", saved.SupportedTypes) +} + +func boolPtrValue(v bool) *bool { + return &v +} diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 59764b298cc8cffc55af8303e4f1af3e3a56c4eb..02d061aeeaad72263e0bb6a87eef02dc2ff73bde 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct { CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` + + VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` } // MethodLimits holds per-payment-type limits. @@ -196,6 +201,8 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingHelpImageURL, SettingHelpText, SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, + SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource, + SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { @@ -234,18 +241,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { + types := make([]string, 0, len(strings.Split(raw, ","))) for _, t := range strings.Split(raw, ",") { t = strings.TrimSpace(t) if t != "" { - cfg.EnabledTypes = append(cfg.EnabledTypes, t) + types = append(types, t) } } + cfg.EnabledTypes = NormalizeVisibleMethods(types) } return cfg } // getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + if s.entClient == nil { + return "" + } instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.EnabledEQ(true), @@ -282,25 +294,29 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda } } m := map[string]string{ - SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), - SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), - SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), - SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), - SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), - SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), - SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), - SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier), - SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate), - SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), - SettingProductNamePrefix: derefStr(req.ProductNamePrefix), - SettingProductNameSuffix: derefStr(req.ProductNameSuffix), - SettingHelpImageURL: derefStr(req.HelpImageURL), - SettingHelpText: derefStr(req.HelpText), - SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), - SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), - SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), - SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), - SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), + SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), + SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), + SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), + SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), + SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), + SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), + SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier), + SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate), + SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), + SettingProductNamePrefix: derefStr(req.ProductNamePrefix), + SettingProductNameSuffix: derefStr(req.ProductNameSuffix), + SettingHelpImageURL: derefStr(req.HelpImageURL), + SettingHelpText: derefStr(req.HelpText), + SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), + SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), + SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), + SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), + SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource), + SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource), + SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled), + SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled), } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") @@ -385,3 +401,79 @@ func pcParseInt(s string, defaultVal int) int { } return v } + +func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool { + available := make(map[string]bool, 4) + for _, inst := range instances { + switch inst.ProviderKey { + case payment.TypeAlipay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) { + available[VisibleMethodSourceOfficialAlipay] = true + } + case payment.TypeWxpay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) { + available[VisibleMethodSourceOfficialWechat] = true + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(inst.SupportedTypes) { + switch NormalizeVisibleMethod(supportedType) { + case payment.TypeAlipay: + available[VisibleMethodSourceEasyPayAlipay] = true + case payment.TypeWxpay: + available[VisibleMethodSourceEasyPayWechat] = true + } + } + } + } + return available +} + +func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string { + shouldExpose := map[string]bool{ + payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available), + payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available), + } + + seen := make(map[string]struct{}, len(base)+2) + out := make([]string, 0, len(base)+2) + appendType := func(paymentType string) { + paymentType = NormalizeVisibleMethod(paymentType) + if paymentType == "" { + return + } + if _, ok := seen[paymentType]; ok { + return + } + seen[paymentType] = struct{}{} + out = append(out, paymentType) + } + + for _, paymentType := range base { + visibleMethod := NormalizeVisibleMethod(paymentType) + switch visibleMethod { + case payment.TypeAlipay, payment.TypeWxpay: + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + default: + appendType(visibleMethod) + } + } + + for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + } + return out +} + +func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool { + enabledKey := visibleMethodEnabledSettingKey(method) + sourceKey := visibleMethodSourceSettingKey(method) + if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" { + return false + } + source := NormalizeVisibleMethodSource(method, vals[sourceKey]) + return source != "" && available[source] +} diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 027bb796fcde5f6dbae4d46c7dfc3ee7d2c1e99e..f04f4697b1d64b016e5b85e83219e58d1ce01445 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -1,9 +1,19 @@ package service import ( + "context" + "database/sql" + "fmt" + "strings" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestPcParseFloat(t *testing.T) { @@ -163,6 +173,20 @@ func TestParsePaymentConfig(t *testing.T) { } }) + t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) { + t.Parallel() + vals := map[string]string{ + SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay", + } + cfg := svc.parsePaymentConfig(vals) + if len(cfg.EnabledTypes) != 2 { + t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes)) + } + if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" { + t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes) + } + }) + t.Run("empty enabled types string", func(t *testing.T) { t.Parallel() vals := map[string]string{ @@ -204,3 +228,210 @@ func TestGetBasePaymentType(t *testing.T) { }) } } + +func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) { + t.Parallel() + + base := []string{"alipay", "wxpay", "stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + } + available := map[string]bool{ + VisibleMethodSourceOfficialAlipay: true, + VisibleMethodSourceOfficialWechat: false, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"alipay", "stripe"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) { + t.Parallel() + + base := []string{"stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay, + } + available := map[string]bool{ + VisibleMethodSourceEasyPayAlipay: true, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"stripe", "alipay"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestBuildVisibleMethodSourceAvailability(t *testing.T) { + t.Parallel() + + instances := []*dbent.PaymentProviderInstance{ + {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"}, + {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"}, + {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"}, + } + + got := buildVisibleMethodSourceAvailability(instances) + if !got[VisibleMethodSourceOfficialAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay) + } + if !got[VisibleMethodSourceEasyPayAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay) + } + if !got[VisibleMethodSourceOfficialWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat) + } + if !got[VisibleMethodSourceEasyPayWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat) + } +} + +func TestGetPaymentConfigKeepsStoredEnabledTypes(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingEnabledPaymentTypes: "alipay,wxpay,stripe", + }, + }, + } + + cfg, err := svc.GetPaymentConfig(ctx) + if err != nil { + t.Fatalf("GetPaymentConfig returned error: %v", err) + } + + want := []string{payment.TypeAlipay, payment.TypeWxpay, payment.TypeStripe} + if len(cfg.EnabledTypes) != len(want) { + t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes) + } + for i := range want { + if cfg.EnabledTypes[i] != want[i] { + t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes) + } + } +} + +func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { + t.Helper() + + dbName := fmt.Sprintf( + "file:%s?mode=memory&cache=shared", + strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()), + ) + db, err := sql.Open("sqlite", dbName) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatalf("enable foreign keys: %v", err) + } + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +type paymentConfigSettingRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, nil +} +func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error { + s.updates = make(map[string]string, len(values)) + for key, value := range values { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} +func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } + +func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) { + repo := &paymentConfigSettingRepoStub{values: map[string]string{}} + svc := &PaymentConfigService{settingRepo: repo} + + alipayEnabled := true + wxpayEnabled := false + err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{ + VisibleMethodAlipayEnabled: &alipayEnabled, + VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay), + VisibleMethodWxpayEnabled: &wxpayEnabled, + VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat), + }) + if err != nil { + t.Fatalf("UpdatePaymentConfig returned error: %v", err) + } + + if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" { + t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay { + t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay) + } + if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" { + t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat { + t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat) + } +} + +func paymentConfigStrPtr(value string) *string { + return &value +} diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 44818b37e6ff50d7947ad018d5a85eeb9a38e204..904960ee9c3339f8b39f2276c8b54f8844577034 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -25,22 +25,61 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme // Look up order by out_trade_no (the external order ID we sent to the provider) order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) if err != nil { - // Fallback: try legacy format (sub2_N where N is DB ID) - trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) - if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + // Fallback only for true legacy "sub2_N" DB-ID payloads when the + // current out_trade_no lookup genuinely did not find an order. + if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok { + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata) } return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) } - return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata) } -func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { +func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) { + if !dbent.IsNotFound(lookupErr) { + return 0, false + } + orderID = strings.TrimSpace(orderID) + if !strings.HasPrefix(orderID, orderIDPrefix) { + return 0, false + } + trimmed := strings.TrimPrefix(orderID, orderIDPrefix) + if trimmed == "" || trimmed == orderID { + return 0, false + } + oid, err := strconv.ParseInt(trimmed, 10, 64) + if err != nil || oid <= 0 { + return 0, false + } + return oid, true +} + +func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { slog.Error("order not found", "orderID", oid) return nil } + instanceProviderKey := "" + if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { + instanceProviderKey = inst.ProviderKey + } + expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey) + if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ + "expectedProvider": expectedProviderKey, + "actualProvider": pk, + "tradeNo": tradeNo, + }) + return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) + } + if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{ + "detail": err.Error(), + "tradeNo": tradeNo, + }) + return err + } // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Also skip if paid is NaN/Inf (malformed provider data). if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { @@ -56,6 +95,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } +func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { + return validateProviderSnapshotMetadata(order, providerKey, metadata) +} + +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { + if key := strings.TrimSpace(instanceProviderKey); key != "" { + return key + } + if key := strings.TrimSpace(orderProviderKey); key != "" { + return key + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(orderPaymentType) +} + func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error { previousStatus := o.Status now := time.Now() diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 625b0d9f0ed9bed88489abc63bd7b5a5c24b6eb7..6aed19f849abad1d108b6f5700f674bd6990259d 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -3,12 +3,38 @@ package service import ( + "context" "errors" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) +type paymentFulfillmentTestProvider struct { + key string + supportedTypes []payment.PaymentType +} + +func (p paymentFulfillmentTestProvider) Name() string { return p.key } +func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key } +func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType { + return p.supportedTypes +} +func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -161,3 +187,171 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) { assert.True(t, unusedCode.CanUse()) assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil)) } + +func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay), + ) +} + +func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeEasyPay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { + t.Parallel() + + assert.Equal(t, + payment.TypeWxpay, + expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""), + ) +} + +func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_key": payment.TypeEasyPay, + }, + } + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKeyForOrder(registry, order, ""), + ) +} + +func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "merchant_app_id": "wx-app-expected", + "merchant_id": "mch-expected", + "currency": "CNY", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-other", + "mchid": "mch-expected", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.ErrorContains(t, err, "wxpay appid mismatch") +} + +func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "9", + "provider_key": payment.TypeWxpay, + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-runtime", + "mchid": "mch-runtime", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.NoError(t, err) +} + +func TestParseLegacyPaymentOrderID(t *testing.T) { + t.Parallel() + + oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{}) + assert.True(t, ok) + assert.EqualValues(t, 42, oid) + + _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{}) + assert.False(t, ok) + + _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down")) + assert.False(t, ok) +} + +func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "merchant_app_id": "alipay-app-expected", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{ + "app_id": "alipay-app-other", + }) + assert.ErrorContains(t, err, "alipay app_id mismatch") +} + +func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "merchant_id": "pid-expected", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{ + "pid": "pid-other", + }) + assert.ErrorContains(t, err, "easypay pid mismatch") +} diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index a72120257ac15ad55a1fad25f179aa08c372f4e7..6554526ef6a91c8aa49a7f9e6a56fc6822145a47 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "math" + "net/url" "strconv" "strings" "time" @@ -23,6 +24,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest if req.OrderType == "" { req.OrderType = payment.OrderTypeBalance } + if normalized := NormalizeVisibleMethod(req.PaymentType); normalized != "" { + req.PaymentType = normalized + } cfg, err := s.configService.GetPaymentConfig(ctx) if err != nil { return nil, fmt.Errorf("get payment config: %w", err) @@ -55,11 +59,25 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest feeRate := cfg.RechargeFeeRate payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate) payAmount, _ := strconv.ParseFloat(payAmountStr, 64) - order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount) + sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount) + if err != nil { + return nil, err + } + if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil { + return nil, err + } + oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel) + if err != nil { + return nil, err + } + if oauthResp != nil { + return oauthResp, nil + } + order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel) if err != nil { return nil, err } - resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan) + resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan, sel) if err != nil { _, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID). SetStatus(OrderStatusFailed). @@ -104,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe return plan, nil } -func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) { +func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) { tx, err := s.entClient.Tx(ctx) if err != nil { return nil, fmt.Errorf("begin transaction: %w", err) @@ -121,6 +139,13 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq tm = defaultOrderTimeoutMin } exp := time.Now().Add(time.Duration(tm) * time.Minute) + providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req) + selectedInstanceID := "" + selectedProviderKey := "" + if sel != nil { + selectedInstanceID = strings.TrimSpace(sel.InstanceID) + selectedProviderKey = strings.TrimSpace(sel.ProviderKey) + } b := tx.PaymentOrder.Create(). SetUserID(req.UserID). SetUserEmail(user.Email). @@ -141,6 +166,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq if req.SrcURL != "" { b.SetSrcURL(req.SrcURL) } + if selectedInstanceID != "" { + b.SetProviderInstanceID(selectedInstanceID) + } + if selectedProviderKey != "" { + b.SetProviderKey(selectedProviderKey) + } + if providerSnapshot != nil { + b.SetProviderSnapshot(providerSnapshot) + } if plan != nil { b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit)) } @@ -174,6 +208,65 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return nil } +func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any { + if sel == nil { + return nil + } + + snapshot := map[string]any{} + snapshot["schema_version"] = 2 + + instanceID := strings.TrimSpace(sel.InstanceID) + if instanceID != "" { + snapshot["provider_instance_id"] = instanceID + } + + providerKey := strings.TrimSpace(sel.ProviderKey) + if providerKey != "" { + snapshot["provider_key"] = providerKey + } + + paymentMode := strings.TrimSpace(sel.PaymentMode) + if paymentMode != "" { + snapshot["payment_mode"] = paymentMode + } + + if providerKey == payment.TypeWxpay { + if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" { + snapshot["merchant_app_id"] = merchantAppID + } + if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" { + snapshot["merchant_id"] = merchantID + } + snapshot["currency"] = "CNY" + } + if providerKey == payment.TypeAlipay { + if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" { + snapshot["merchant_app_id"] = merchantAppID + } + } + if providerKey == payment.TypeEasyPay { + if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" { + snapshot["merchant_id"] = merchantID + } + } + + if len(snapshot) == 1 { + return nil + } + return snapshot +} + +func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string { + if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay { + return "" + } + if strings.TrimSpace(req.OpenID) != "" { + return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config)) + } + return strings.TrimSpace(sel.Config["appId"]) +} + func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil @@ -198,10 +291,12 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user return nil } -func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { - // Select an instance across all providers that support the requested payment type. - // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). - sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) +func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) { + selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req) + if err != nil { + return nil, err + } + sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) if err != nil { return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured"). WithMetadata(map[string]string{"payment_type": req.PaymentType}) @@ -209,6 +304,45 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen if sel == nil { return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance") } + return sel, nil +} + +func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) { + if !requestNeedsWeChatJSAPICompatibility(req) { + return ctx, nil + } + if !s.usesOfficialWxpayVisibleMethod(ctx) { + return ctx, nil + } + expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return nil, err + } + return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil +} + +func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool { + if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return false + } + return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" +} + +func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool { + if s == nil || s.configService == nil { + return false + } + inst, err := s.configService.resolveEnabledVisibleMethodInstance(ctx, payment.TypeWxpay) + if err != nil { + return false + } + if inst == nil { + return false + } + return inst.ProviderKey == payment.TypeWxpay +} + +func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) @@ -226,16 +360,52 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } subject := s.buildPaymentSubject(plan, limitAmount, cfg) outTradeNo := order.OutTradeNo - pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) + canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost) + if err != nil { + return nil, err + } + resumeToken := "" + if resume := s.paymentResume(); resume != nil { + if resume.isSigningConfigured() { + resumeToken, err = resume.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: order.UserID, + ProviderInstanceID: sel.InstanceID, + ProviderKey: sel.ProviderKey, + PaymentType: req.PaymentType, + CanonicalReturnURL: canonicalReturnURL, + }) + if err != nil { + return nil, fmt.Errorf("create payment resume token: %w", err) + } + } + } + providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken) + if err != nil { + return nil, err + } + providerReq := buildProviderCreatePaymentRequest(CreateOrderRequest{ + PaymentType: req.PaymentType, + OpenID: req.OpenID, + ClientIP: req.ClientIP, + IsMobile: req.IsMobile, + ReturnURL: providerReturnURL, + }, sel, outTradeNo, payAmountStr, subject) + pr, err := prov.CreatePayment(ctx, providerReq) if err != nil { slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) { return nil, appErr } - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment_gateway_error"). - WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}) - } - _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) + return nil, classifyCreatePaymentError(req, sel.ProviderKey, err) + } + _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID). + SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)). + SetNillablePayURL(psNilIfEmpty(pr.PayURL)). + SetNillableQrCode(psNilIfEmpty(pr.QRCode)). + SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)). + SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)). + Save(ctx) if err != nil { return nil, fmt.Errorf("update order with payment details: %w", err) } @@ -245,8 +415,36 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen "payAmount": order.PayAmount, "paymentType": req.PaymentType, "orderType": req.OrderType, + "paymentSource": NormalizePaymentSource(req.PaymentSource), }) - return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil + resultType := pr.ResultType + if resultType == "" { + resultType = payment.CreatePaymentResultOrderCreated + } + resp := buildCreateOrderResponse(order, req, payAmount, sel, pr, resultType) + resp.ResumeToken = resumeToken + return resp, nil +} + +func buildProviderCreatePaymentRequest(req CreateOrderRequest, sel *payment.InstanceSelection, orderID, amount, subject string) payment.CreatePaymentRequest { + return payment.CreatePaymentRequest{ + OrderID: orderID, + Amount: amount, + PaymentType: req.PaymentType, + Subject: subject, + ReturnURL: req.ReturnURL, + OpenID: strings.TrimSpace(req.OpenID), + ClientIP: req.ClientIP, + IsMobile: req.IsMobile, + InstanceSubMethods: selectedInstanceSupportedTypes(sel), + } +} + +func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string { + if sel == nil { + return "" + } + return sel.SupportedTypes } func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string { @@ -265,6 +463,190 @@ func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limit return "Sub2API " + amountStr + " CNY" } +func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) { + return s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil) +} + +func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponseForSelection(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { + if sel != nil && sel.ProviderKey != "" && sel.ProviderKey != payment.TypeWxpay { + return nil, nil + } + if strings.TrimSpace(req.OpenID) != "" || !req.IsWeChatBrowser || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return nil, nil + } + return s.buildWeChatOAuthRequiredResponse(ctx, req, amount, payAmount, feeRate) +} + +func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) { + appID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return nil, err + } + + authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base") + if err != nil { + return nil, err + } + + return &CreateOrderResponse{ + Amount: amount, + PayAmount: payAmount, + FeeRate: feeRate, + ResultType: payment.CreatePaymentResultOAuthRequired, + PaymentType: req.PaymentType, + OAuth: &payment.WechatOAuthInfo{ + AuthorizeURL: authorizeURL, + AppID: appID, + Scope: "snsapi_base", + RedirectURL: "/auth/wechat/payment/callback", + }, + }, nil +} + +func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context, req CreateOrderRequest, sel *payment.InstanceSelection) error { + if !requiresWeChatJSAPICompatibleSelection(req, sel) { + return nil + } + expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return err + } + selectedAppID := provider.ResolveWxpayJSAPIAppID(sel.Config) + if selectedAppID == "" || selectedAppID != expectedAppID { + return infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "selected payment instance is not compatible with the current WeChat OAuth app") + } + return nil +} + +func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool { + if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return false + } + return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" +} + +func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) { + if s == nil || s.configService == nil || s.configService.settingRepo == nil { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx) + appID := strings.TrimSpace(cfg.AppIDForMode("mp")) + appSecret := strings.TrimSpace(cfg.AppSecretForMode("mp")) + if err != nil || !cfg.SupportsMode("mp") || appID == "" || appSecret == "" { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + return appID, appSecret, nil +} + +func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error { + if err == nil { + return nil + } + if providerKey == payment.TypeWxpay && + payment.GetBasePaymentType(req.PaymentType) == payment.TypeWxpay && + strings.Contains(err.Error(), "wxpay h5 payments are not authorized for this merchant") { + return infraerrors.ServiceUnavailable( + "WECHAT_H5_NOT_AUTHORIZED", + "wechat h5 payment is not available for this merchant", + ).WithMetadata(map[string]string{ + "action": "open_in_wechat_or_scan_qr", + }) + } + return infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) +} + +func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest, payAmount float64, sel *payment.InstanceSelection, pr *payment.CreatePaymentResponse, resultType payment.CreatePaymentResultType) *CreateOrderResponse { + return &CreateOrderResponse{ + OrderID: order.ID, + Amount: order.Amount, + PayAmount: payAmount, + FeeRate: order.FeeRate, + Status: OrderStatusPending, + ResultType: resultType, + PaymentType: req.PaymentType, + OutTradeNo: order.OutTradeNo, + PayURL: pr.PayURL, + QRCode: pr.QRCode, + ClientSecret: pr.ClientSecret, + OAuth: pr.OAuth, + JSAPI: pr.JSAPI, + JSAPIPayload: pr.JSAPI, + ExpiresAt: order.ExpiresAt, + PaymentMode: sel.PaymentMode, + } +} + +func buildWeChatPaymentOAuthStartURL(req CreateOrderRequest, scope string) (string, error) { + u, err := url.Parse("/api/v1/auth/oauth/wechat/payment/start") + if err != nil { + return "", fmt.Errorf("build wechat payment oauth start url: %w", err) + } + q := u.Query() + q.Set("payment_type", strings.TrimSpace(req.PaymentType)) + if req.Amount > 0 { + q.Set("amount", strconv.FormatFloat(req.Amount, 'f', -1, 64)) + } + if orderType := strings.TrimSpace(req.OrderType); orderType != "" { + q.Set("order_type", orderType) + } + if req.PlanID > 0 { + q.Set("plan_id", strconv.FormatInt(req.PlanID, 10)) + } + if scope = strings.TrimSpace(scope); scope != "" { + q.Set("scope", scope) + } + if redirectTo := paymentRedirectPathFromURL(req.SrcURL); redirectTo != "" { + q.Set("redirect", redirectTo) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +func paymentRedirectPathFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "/purchase" + } + if strings.HasPrefix(rawURL, "/") && !strings.HasPrefix(rawURL, "//") { + return normalizePaymentRedirectPath(rawURL) + } + u, err := url.Parse(rawURL) + if err != nil { + return "/purchase" + } + path := strings.TrimSpace(u.EscapedPath()) + if path == "" { + path = strings.TrimSpace(u.Path) + } + if path == "" || !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") { + return "/purchase" + } + if strings.TrimSpace(u.RawQuery) != "" { + path += "?" + u.RawQuery + } + return normalizePaymentRedirectPath(path) +} + +func normalizePaymentRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "/purchase" + } + if path == "/payment" { + return "/purchase" + } + if strings.HasPrefix(path, "/payment?") { + return "/purchase" + strings.TrimPrefix(path, "/payment") + } + return path +} + // --- Order Queries --- func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) { diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a89d03801cb9c96d85f726a0cf52b3f3bb946669 --- /dev/null +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -0,0 +1,33 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{entClient: client}, + } + + if !svc.usesOfficialWxpayVisibleMethod(ctx) { + t.Fatal("expected official wxpay visible method to be detected from enabled provider instance") + } +} diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index 801471804c7597ea667ca86279b35177a3aec95f..ccab7c1181bf9f9c9fa2c218e6265e6e39c7810e 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -139,30 +140,86 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s if err != nil { return "" } - // Use OutTradeNo as fallback when PaymentTradeNo is empty - // (e.g. EasyPay popup mode where trade_no arrives only via notify callback) - tradeNo := o.PaymentTradeNo - if tradeNo == "" { - tradeNo = o.OutTradeNo + queryRef := paymentOrderQueryReference(o, prov) + if queryRef == "" { + return "" } - resp, err := prov.QueryOrder(ctx, tradeNo) + resp, err := prov.QueryOrder(ctx, queryRef) if err != nil { slog.Warn("query upstream failed", "orderID", o.ID, "error", err) return "" } if resp.Status == payment.ProviderStatusPaid { - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { + notificationTradeNo := o.PaymentTradeNo + if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) { + if _, updateErr := s.entClient.PaymentOrder.Update(). + Where(paymentorder.IDEQ(o.ID)). + SetPaymentTradeNo(upstreamTradeNo). + Save(ctx); updateErr != nil { + slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr) + } else { + o.PaymentTradeNo = upstreamTradeNo + } + notificationTradeNo = upstreamTradeNo + } + if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil { slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) // Still return already_paid — order was paid, fulfillment can be retried } return checkPaidResultAlreadyPaid } if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, tradeNo) + _ = cp.CancelPayment(ctx, queryRef) } return "" } +func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string { + if order == nil { + return "" + } + + providerKey := "" + if prov != nil { + providerKey = strings.TrimSpace(prov.ProviderKey()) + } + if providerKey == "" { + if snapshot := psOrderProviderSnapshot(order); snapshot != nil { + providerKey = strings.TrimSpace(snapshot.ProviderKey) + } + } + if providerKey == "" { + providerKey = strings.TrimSpace(psStringValue(order.ProviderKey)) + } + if providerKey == "" { + providerKey = strings.TrimSpace(order.PaymentType) + } + + switch payment.GetBasePaymentType(providerKey) { + case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay: + return strings.TrimSpace(order.OutTradeNo) + default: + if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" { + return tradeNo + } + return strings.TrimSpace(order.OutTradeNo) + } +} + +func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool { + upstreamTradeNo = strings.TrimSpace(upstreamTradeNo) + if upstreamTradeNo == "" { + return false + } + if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) { + return false + } + if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) { + return false + } + return true +} + // VerifyOrderByOutTradeNo actively queries the upstream provider to check // if a payment was made, and processes it if so. This handles the case where // the provider's notify callback was missed (e.g. EasyPay popup mode). @@ -190,8 +247,9 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo return o, nil } -// VerifyOrderPublic verifies payment status without user authentication. -// Used by the payment result page when the user's session has expired. +// VerifyOrderPublic returns the currently persisted public order state without +// triggering any upstream reconciliation. Signed resume-token recovery is the +// only public recovery path allowed to query upstream state. func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). @@ -199,15 +257,6 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin if err != nil { return nil, infraerrors.NotFound("NOT_FOUND", "order not found") } - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == checkPaidResultAlreadyPaid { - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } return o, nil } @@ -236,22 +285,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) + } + if !paymentOrderAllowsRegistryFallback(o) { + return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID) + } + providerKey := paymentOrderFallbackProviderKey(s.registry, o) + if providerKey == "" { + return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } + +func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool { + if order == nil { + return false + } + if psOrderProviderSnapshot(order) != nil { + return false + } + if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" { + return false + } + if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" { + return false + } + return true +} + +func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string { + if order == nil { + return "" + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType))) +} + +func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { + if inst == nil { + return nil, fmt.Errorf("payment provider instance is missing") + } + + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + if inst.PaymentMode != "" { + cfg["paymentMode"] = inst.PaymentMode + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create provider from instance: %w", err) + } + return prov, nil +} + +func psStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go new file mode 100644 index 0000000000000000000000000000000000000000..39993a2f14828d561488df8d3962dbc8c7ea83ab --- /dev/null +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -0,0 +1,377 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type paymentOrderLifecycleQueryProvider struct { + lastQueryTradeNo string + resp *payment.QueryOrderResponse +} + +type paymentOrderLifecycleRedeemRepo struct { + codesByCode map[string]*RedeemCode + useCalls []struct { + id int64 + userID int64 + } +} + +func (p *paymentOrderLifecycleQueryProvider) Name() string { + return "payment-order-lifecycle-query-provider" +} + +func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay } + +func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.TypeAlipay} +} + +func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} + +func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + p.lastQueryTradeNo = tradeNo + return p.resp, nil +} + +func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} + +func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) { + for _, code := range r.codesByCode { + if code.ID != id { + continue + } + cloned := *code + return &cloned, nil + } + return nil, ErrRedeemCodeNotFound +} + +func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) { + redeemCode, ok := r.codesByCode[code] + if !ok { + return nil, ErrRedeemCodeNotFound + } + cloned := *redeemCode + return &cloned, nil +} + +func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error { + for code, redeemCode := range r.codesByCode { + if redeemCode.ID != id { + continue + } + now := time.Now().UTC() + redeemCode.Status = StatusUsed + redeemCode.UsedBy = &userID + redeemCode.UsedAt = &now + r.codesByCode[code] = redeemCode + r.useCalls = append(r.useCalls, struct { + id int64 + userID int64 + }{id: id, userID: userID}) + return nil + } + return ErrRedeemCodeNotFound +} + +func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected call") +} + +func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO"). + SetOutTradeNo("sub2_checkpaid_trade_no_missing"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-123", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, OrderStatusCompleted, got.Status) + require.Equal(t, "upstream-trade-123", got.PaymentTradeNo) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusCompleted, reloaded.Status) + require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo) + + require.Equal(t, 88.0, userRepo.getByIDUser.Balance) + require.Len(t, redeemRepo.useCalls, 1) + require.Equal(t, int64(1), redeemRepo.useCalls[0].id) + require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) +} + +func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-existing-trade@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-existing-trade-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO"). + SetOutTradeNo("sub2_checkpaid_use_out_trade_no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("upstream-trade-existing"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-existing", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo) +} + +func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) { + t.Parallel() + + require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + })) + + instanceID := "12" + require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderInstanceID: &instanceID, + })) + + require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "provider_instance_id": "12", + }, + })) +} + +func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + OutTradeNo: "sub2_out_trade_no", + PaymentTradeNo: "wx-transaction-id", + } + + require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{})) + require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{ + key: payment.TypeWxpay, + })) +} + +func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go new file mode 100644 index 0000000000000000000000000000000000000000..bb60f9e25b884bb93f1c611a34c4aa9375ce1514 --- /dev/null +++ b/backend/internal/service/payment_order_provider_snapshot.go @@ -0,0 +1,205 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +type paymentOrderProviderSnapshot struct { + SchemaVersion int + ProviderInstanceID string + ProviderKey string + PaymentMode string + MerchantAppID string + MerchantID string + Currency string +} + +func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot { + if order == nil || len(order.ProviderSnapshot) == 0 { + return nil + } + + snapshot := &paymentOrderProviderSnapshot{ + SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]), + ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]), + ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]), + PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]), + MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]), + MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]), + Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]), + } + if snapshot.SchemaVersion == 0 && + snapshot.ProviderInstanceID == "" && + snapshot.ProviderKey == "" && + snapshot.PaymentMode == "" && + snapshot.MerchantAppID == "" && + snapshot.MerchantID == "" && + snapshot.Currency == "" { + return nil + } + return snapshot +} + +func psSnapshotStringValue(value any) string { + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + default: + return "" + } +} + +func psSnapshotIntValue(value any) int { + switch typed := value.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float32: + return int(typed) + case float64: + return int(typed) + case string: + n, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return n + } + } + return 0 +} + +func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil || order == nil || snapshot == nil { + return nil, nil + } + + snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID) + columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) + if snapshotInstanceID == "" { + snapshotInstanceID = columnInstanceID + } + if snapshotInstanceID == "" { + return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID) + } + if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) { + return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID) + } + + instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64) + if err != nil { + return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID) + } + + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID) + } + return nil, err + } + + if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) { + return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey) + } + + return inst, nil +} + +func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string { + if order == nil { + return strings.TrimSpace(instanceProviderKey) + } + + orderProviderKey := psStringValue(order.ProviderKey) + if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" { + orderProviderKey = snapshot.ProviderKey + } + + return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey) +} + +func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { + if order == nil || len(metadata) == 0 { + return nil + } + + snapshot := psOrderProviderSnapshot(order) + if snapshot == nil { + return nil + } + + switch strings.TrimSpace(providerKey) { + case payment.TypeWxpay: + if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" { + actual := strings.TrimSpace(metadata["appid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing appid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" { + actual := strings.TrimSpace(metadata["mchid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing mchid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.Currency); expected != "" { + actual := strings.ToUpper(strings.TrimSpace(metadata["currency"])) + if actual == "" { + return fmt.Errorf("wxpay notification missing currency") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual) + } + } + if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") { + return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual) + } + case payment.TypeAlipay: + if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" { + actual := strings.TrimSpace(metadata["app_id"]) + if actual == "" { + return fmt.Errorf("alipay app_id missing") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual) + } + } + case payment.TypeEasyPay: + if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" { + actual := strings.TrimSpace(metadata["pid"]) + if actual == "" { + return fmt.Errorf("easypay pid missing") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual) + } + } + } + + return nil +} + +func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string { + if prov == nil { + return nil + } + reporter, ok := prov.(payment.MerchantIdentityProvider) + if !ok { + return nil + } + return reporter.MerchantIdentityMetadata() +} diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go new file mode 100644 index 0000000000000000000000000000000000000000..efa013b52dd6e0e1d9f2013c6a320c53d28d8887 --- /dev/null +++ b/backend/internal/service/payment_order_provider_snapshot_test.go @@ -0,0 +1,172 @@ +//go:build unit + +package service + +import ( + "context" + "strconv" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) { + t.Parallel() + + sel := &payment.InstanceSelection{ + InstanceID: "12", + ProviderKey: payment.TypeWxpay, + SupportedTypes: "wxpay,wxpay_direct", + PaymentMode: "popup", + Config: map[string]string{ + "privateKey": "secret", + "apiV3Key": "secret-v3", + "appId": "wx-app-id", + }, + } + + snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{}) + require.Equal(t, map[string]any{ + "schema_version": 2, + "provider_instance_id": "12", + "provider_key": payment.TypeWxpay, + "payment_mode": "popup", + "merchant_app_id": "wx-app-id", + "currency": "CNY", + }, snapshot) + require.NotContains(t, snapshot, "config") + require.NotContains(t, snapshot, "privateKey") + require.NotContains(t, snapshot, "apiV3Key") + require.NotContains(t, snapshot, "supported_types") + require.NotContains(t, snapshot, "instance_name") + require.NotContains(t, snapshot, "merchant_id") +} + +func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("snapshot@example.com"). + SetPasswordHash("hash"). + SetUsername("snapshot-user"). + Save(ctx) + require.NoError(t, err) + + instance, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Primary Alipay"). + SetConfig(`{"secretKey":"do-not-copy"}`). + SetSupportedTypes("alipay,alipay_direct"). + SetPaymentMode("redirect"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{entClient: client} + order, err := svc.createOrderInTx( + ctx, + CreateOrderRequest{ + UserID: user.ID, + PaymentType: payment.TypeAlipay, + OrderType: payment.OrderTypeBalance, + ClientIP: "127.0.0.1", + SrcHost: "app.example.com", + }, + &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + }, + nil, + &PaymentConfig{ + MaxPendingOrders: 3, + OrderTimeoutMin: 30, + }, + 88, + 88, + 0, + 88, + &payment.InstanceSelection{ + InstanceID: strconv.FormatInt(instance.ID, 10), + ProviderKey: payment.TypeAlipay, + SupportedTypes: "alipay,alipay_direct", + PaymentMode: "redirect", + Config: map[string]string{ + "secretKey": "do-not-copy", + }, + }, + ) + require.NoError(t, err) + require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID)) + require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey)) + require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"]) + require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"]) + require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"]) + require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"]) + require.NotContains(t, order.ProviderSnapshot, "config") + require.NotContains(t, order.ProviderSnapshot, "secretKey") + require.NotContains(t, order.ProviderSnapshot, "supported_types") + require.NotContains(t, order.ProviderSnapshot, "instance_name") +} + +func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "88", + ProviderKey: payment.TypeWxpay, + Config: map[string]string{ + "appId": "wx-open-app", + "mpAppId": "wx-mp-app", + "mchId": "mch-88", + }, + PaymentMode: "jsapi", + }, CreateOrderRequest{OpenID: "openid-123"}) + + require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"]) + require.Equal(t, "mch-88", snapshot["merchant_id"]) + require.Equal(t, "CNY", snapshot["currency"]) +} + +func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "21", + ProviderKey: payment.TypeAlipay, + Config: map[string]string{ + "appId": "alipay-app-21", + "privateKey": "secret", + }, + PaymentMode: "redirect", + }, CreateOrderRequest{}) + + require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"]) + require.NotContains(t, snapshot, "privateKey") +} + +func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "66", + ProviderKey: payment.TypeEasyPay, + Config: map[string]string{ + "pid": "easypay-merchant-66", + "pkey": "secret", + }, + PaymentMode: "popup", + }, CreateOrderRequest{PaymentType: payment.TypeAlipay}) + + require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"]) + require.NotContains(t, snapshot, "pkey") +} + +func valueOrEmpty(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1675732399b71345cd802965213575dff8a84a17 --- /dev/null +++ b/backend/internal/service/payment_order_result_test.go @@ -0,0 +1,195 @@ +package service + +import ( + "context" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) { + t.Parallel() + + expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) + resp := buildCreateOrderResponse( + &dbent.PaymentOrder{ + ID: 42, + Amount: 12.34, + FeeRate: 0.03, + ExpiresAt: expiresAt, + OutTradeNo: "sub2_42", + }, + CreateOrderRequest{PaymentType: payment.TypeWxpay}, + 12.71, + &payment.InstanceSelection{PaymentMode: "qrcode"}, + &payment.CreatePaymentResponse{ + TradeNo: "sub2_42", + QRCode: "weixin://wxpay/bizpayurl?pr=test", + }, + payment.CreatePaymentResultOrderCreated, + ) + + if resp.ResultType != payment.CreatePaymentResultOrderCreated { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated) + } + if resp.OutTradeNo != "sub2_42" { + t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42") + } + if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" { + t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test") + } + if resp.JSAPI != nil || resp.JSAPIPayload != nil { + t.Fatal("order_created response should not include jsapi payload") + } + if !resp.ExpiresAt.Equal(expiresAt) { + t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt) + } +} + +func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { + t.Parallel() + + jsapiPayload := &payment.WechatJSAPIPayload{ + AppID: "wx123", + TimeStamp: "1712345678", + NonceStr: "nonce-123", + Package: "prepay_id=wx123", + SignType: "RSA", + PaySign: "signed-payload", + } + resp := buildCreateOrderResponse( + &dbent.PaymentOrder{ + ID: 88, + Amount: 66.88, + FeeRate: 0.01, + ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC), + OutTradeNo: "sub2_88", + }, + CreateOrderRequest{PaymentType: payment.TypeWxpay}, + 67.55, + &payment.InstanceSelection{PaymentMode: "popup"}, + &payment.CreatePaymentResponse{ + TradeNo: "sub2_88", + ResultType: payment.CreatePaymentResultJSAPIReady, + JSAPI: jsapiPayload, + }, + payment.CreatePaymentResultJSAPIReady, + ) + + if resp.ResultType != payment.CreatePaymentResultJSAPIReady { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady) + } + if resp.JSAPI == nil || resp.JSAPIPayload == nil { + t.Fatal("expected jsapi payload aliases to be populated") + } + if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload { + t.Fatal("expected jsapi aliases to preserve the original pointer") + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatal("expected oauth_required response, got nil") + } + if resp.ResultType != payment.CreatePaymentResultOAuthRequired { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired) + } + if resp.OAuth == nil { + t.Fatal("expected oauth payload, got nil") + } + if resp.OAuth.AppID != "wx123456" { + t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456") + } + if resp.OAuth.Scope != "snsapi_base" { + t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base") + } + if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" { + t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback") + } + if resp.OAuth.AuthorizeURL != "/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" { + t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL) + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) { + t.Parallel() + + svc := newWeChatPaymentOAuthTestService(nil) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } + if err == nil { + t.Fatal("expected error, got nil") + } + + appErr := infraerrors.FromError(err) + if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" { + t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED") + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03, &payment.InstanceSelection{ + ProviderKey: payment.TypeEasyPay, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } +} + +func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService { + return &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: values}, + }, + } +} diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index c5bda763cd96324a5fda22ba91b6da14c93964ff..7521878c7dd5520f832b3ac9e3c719b7623b7033 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -19,18 +20,133 @@ import ( // --- Refund Flow --- // getOrderProviderInstance looks up the provider instance that processed this order. -// Returns nil, nil for legacy orders without provider_instance_id. +// For legacy orders without provider_instance_id, it resolves only when the +// historical instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { - if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + if s == nil || s.entClient == nil || o == nil { return nil, nil } - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + + if snapshot := psOrderProviderSnapshot(o); snapshot != nil { + return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot) + } + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) if err != nil { return nil, nil } return s.entClient.PaymentProviderInstance.Get(ctx, instID) } +// getRefundOrderProviderInstance resolves the provider instance for refund paths. +// Refunds must be pinned to an explicit historical binding, so legacy +// "best-effort" provider guessing is intentionally not allowed here. +func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil || o == nil { + return nil, nil + } + + if snapshot := psOrderProviderSnapshot(o); snapshot != nil { + return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot) + } + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return nil, nil + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr) + } + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr) + } + return nil, err + } + return inst, nil +} + +func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey != "" { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.ProviderKeyEQ(providerKey)). + All(ctx) + if err != nil { + return nil, err + } + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil + } + + if paymentType == "" { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + All(ctx) + if err != nil { + return nil, err + } + + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance { + if len(instances) == 0 { + return nil + } + if strings.TrimSpace(orderPaymentType) == "" { + return instances + } + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(orderPaymentType, inst) { + matched = append(matched, inst) + } + } + return matched +} + +func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { + if inst == nil { + return false + } + + baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType)) + instanceProviderKey := strings.TrimSpace(inst.ProviderKey) + if baseType == "" { + return false + } + + if baseType == payment.TypeStripe { + return instanceProviderKey == payment.TypeStripe + } + if instanceProviderKey == payment.TypeStripe { + return false + } + if instanceProviderKey == baseType { + return true + } + return payment.InstanceSupportsType(inst.SupportedTypes, baseType) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { @@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } // Check provider instance allows user refund - inst, err := s.getOrderProviderInstance(ctx, o) + inst, err := s.getRefundOrderProviderInstance(ctx, o) if err != nil || inst == nil { return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order") } @@ -92,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } // Check provider instance allows admin refund - inst, instErr := s.getOrderProviderInstance(ctx, o) + inst, instErr := s.getRefundOrderProviderInstance(ctx, o) if instErr != nil { slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr) return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order") @@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error { if err != nil { return fmt.Errorf("get refund provider: %w", err) } + if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil { + s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{ + "detail": err.Error(), + }) + return err + } _, err = prov.Refund(ctx, payment.RefundRequest{ TradeNo: p.Order.PaymentTradeNo, OrderID: p.Order.OutTradeNo, @@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error { // getRefundProvider creates a provider using the order's original instance config. // Delegates to getOrderProvider which handles instance lookup and fallback. func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - return s.getOrderProvider(ctx, o) + inst, err := s.getRefundOrderProviderInstance(ctx, o) + if err != nil { + return nil, err + } + if inst == nil { + return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID) + } + return s.createProviderFromInstance(ctx, inst) } func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) { diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ca5b62cb28d94551d14779d46c860e808b2e68a4 --- /dev/null +++ b/backend/internal/service/payment_refund_test.go @@ -0,0 +1,186 @@ +//go:build unit + +package service + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-legacy@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-legacy-user"). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-instance"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetAllowUserRefund(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("REFUND-LEGACY-ORDER"). + SetOutTradeNo("sub2_refund_legacy_order"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-legacy-refund"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + } + + _, err = svc.validateRefundRequest(ctx, order.ID, user.ID) + require.Error(t, err) + require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err)) +} + +func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-legacy-admin@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-legacy-admin-user"). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-admin-instance"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetAllowUserRefund(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(188). + SetPayAmount(188). + SetFeeRate(0). + SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER"). + SetOutTradeNo("sub2_refund_legacy_admin_order"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-legacy-admin-refund"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + } + + plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false) + require.Nil(t, plan) + require.Nil(t, result) + require.Error(t, err) + require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err)) +} + +func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-snapshot-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-snapshot-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-mismatch-instance"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{ + "appId": "runtime-alipay-app", + "privateKey": "runtime-private-key", + })). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + instID := strconv.FormatInt(inst.ID, 10) + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER"). + SetOutTradeNo("sub2_refund_snapshot_mismatch_order"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-refund-snapshot-mismatch"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instID). + SetProviderKey(payment.TypeAlipay). + SetProviderSnapshot(map[string]any{ + "schema_version": 2, + "provider_instance_id": instID, + "provider_key": payment.TypeAlipay, + "merchant_app_id": "expected-alipay-app", + }). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + err = svc.gwRefund(ctx, &RefundPlan{ + OrderID: order.ID, + Order: order, + RefundAmount: order.Amount, + GatewayAmount: order.Amount, + Reason: "snapshot mismatch", + }) + require.ErrorContains(t, err, "alipay app_id mismatch") +} diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go new file mode 100644 index 0000000000000000000000000000000000000000..05626aa674a59dd7b9e0876b68fa583b6f9ba9b4 --- /dev/null +++ b/backend/internal/service/payment_resume_lookup.go @@ -0,0 +1,59 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" +) + +func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { + claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token)) + if err != nil { + return nil, err + } + + order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) + if err != nil { + return nil, fmt.Errorf("get order by resume token: %w", err) + } + if claims.UserID > 0 && order.UserID != claims.UserID { + return nil, fmt.Errorf("resume token user mismatch") + } + snapshot := psOrderProviderSnapshot(order) + orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) + orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey)) + if snapshot != nil { + if snapshot.ProviderInstanceID != "" { + orderProviderInstanceID = snapshot.ProviderInstanceID + } + if snapshot.ProviderKey != "" { + orderProviderKey = snapshot.ProviderKey + } + } + if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { + return nil, fmt.Errorf("resume token provider instance mismatch") + } + if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey { + return nil, fmt.Errorf("resume token provider key mismatch") + } + if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { + return nil, fmt.Errorf("resume token payment type mismatch") + } + if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { + result := s.checkPaid(ctx, order) + if result == checkPaidResultAlreadyPaid { + order, err = s.entClient.PaymentOrder.Get(ctx, order.ID) + if err != nil { + return nil, fmt.Errorf("reload order by resume token: %w", err) + } + } + } + + return order, nil +} + +func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { + return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) +} diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go new file mode 100644 index 0000000000000000000000000000000000000000..946e7aa2126c6cf032b14704a61cc2dfae92e01c --- /dev/null +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -0,0 +1,304 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +type paymentResumeLookupProvider struct { + queryCount int +} + +func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" } + +func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay } + +func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.TypeAlipay} +} + +func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} + +func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + p.queryCount++ + return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil +} + +func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} + +func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := "12" + providerKey := payment.TypeEasyPay + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-ORDER"). + SetOutTradeNo("sub2_resume_lookup"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-1"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). + SetProviderKey(providerKey). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: instanceID, + ProviderKey: providerKey, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-MISMATCH"). + SetOutTradeNo("sub2_resume_lookup_mismatch"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-2"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("12"). + SetProviderKey(payment.TypeEasyPay). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "99", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + _, err = svc.GetPublicOrderByResumeToken(ctx, token) + require.Error(t, err) + require.Contains(t, err.Error(), "resume token") +} + +func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-snapshot-authority@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-snapshot-authority-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY"). + SetOutTradeNo("sub2_resume_snapshot_authority"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-snapshot-authority"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("legacy-column-instance"). + SetProviderKey(payment.TypeAlipay). + SetProviderSnapshot(map[string]any{ + "schema_version": 2, + "provider_instance_id": "snapshot-instance", + "provider_key": payment.TypeEasyPay, + }). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "snapshot-instance", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-refresh@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-refresh-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-PENDING"). + SetOutTradeNo("sub2_resume_lookup_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-pending"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentResumeLookupProvider{} + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + resumeService: resumeSvc, + providersLoaded: true, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) + require.Equal(t, 1, provider.queryCount) +} + +func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("public-verify@example.com"). + SetPasswordHash("hash"). + SetUsername("public-verify-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("PUBLIC-VERIFY"). + SetOutTradeNo("sub2_public_verify_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-verify"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentResumeLookupProvider{} + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) + require.Equal(t, 0, provider.queryCount) +} diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go new file mode 100644 index 0000000000000000000000000000000000000000..6e8acccbfd4c9ed7affa7e93222fb3fb8b7650b4 --- /dev/null +++ b/backend/internal/service/payment_resume_service.go @@ -0,0 +1,418 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const paymentResultReturnPath = "/payment/result" + +const ( + PaymentSourceHostedRedirect = "hosted_redirect" + PaymentSourceWechatInAppResume = "wechat_in_app_resume" + + SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" + SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" + SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" + SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled" + + VisibleMethodSourceOfficialAlipay = "official_alipay" + VisibleMethodSourceEasyPayAlipay = "easypay_alipay" + VisibleMethodSourceOfficialWechat = "official_wxpay" + VisibleMethodSourceEasyPayWechat = "easypay_wxpay" + + wechatPaymentResumeTokenType = "wechat_payment_resume" + + paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED" + paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key" + + paymentResumeTokenTTL = 24 * time.Hour + wechatPaymentResumeTokenTTL = 15 * time.Minute +) + +type ResumeTokenClaims struct { + OrderID int64 `json:"oid"` + UserID int64 `json:"uid,omitempty"` + ProviderInstanceID string `json:"pi,omitempty"` + ProviderKey string `json:"pk,omitempty"` + PaymentType string `json:"pt,omitempty"` + CanonicalReturnURL string `json:"ru,omitempty"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp,omitempty"` +} + +type WeChatPaymentResumeClaims struct { + TokenType string `json:"tk,omitempty"` + OpenID string `json:"openid"` + PaymentType string `json:"pt,omitempty"` + Amount string `json:"amt,omitempty"` + OrderType string `json:"ot,omitempty"` + PlanID int64 `json:"pid,omitempty"` + RedirectTo string `json:"rd,omitempty"` + Scope string `json:"scp,omitempty"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp,omitempty"` +} + +type PaymentResumeService struct { + signingKey []byte +} + +type visibleMethodLoadBalancer struct { + inner payment.LoadBalancer + configService *PaymentConfigService +} + +func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { + return &PaymentResumeService{signingKey: signingKey} +} + +func (s *PaymentResumeService) isSigningConfigured() bool { + return s != nil && len(s.signingKey) > 0 +} + +func (s *PaymentResumeService) ensureSigningKey() error { + if s.isSigningConfigured() { + return nil + } + return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage) +} + +func NormalizeVisibleMethod(method string) string { + return payment.GetBasePaymentType(strings.TrimSpace(method)) +} + +func NormalizeVisibleMethods(methods []string) []string { + if len(methods) == 0 { + return nil + } + seen := make(map[string]struct{}, len(methods)) + out := make([]string, 0, len(methods)) + for _, method := range methods { + normalized := NormalizeVisibleMethod(method) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out +} + +func NormalizePaymentSource(source string) string { + switch strings.TrimSpace(strings.ToLower(source)) { + case "", PaymentSourceHostedRedirect: + return PaymentSourceHostedRedirect + case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume: + return PaymentSourceWechatInAppResume + default: + return strings.TrimSpace(strings.ToLower(source)) + } +} + +func NormalizeVisibleMethodSource(method, source string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official": + return VisibleMethodSourceOfficialAlipay + case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayAlipay + } + case payment.TypeWxpay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official": + return VisibleMethodSourceOfficialWechat + case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayWechat + } + } + return "" +} + +func VisibleMethodProviderKeyForSource(method, source string) (string, bool) { + switch NormalizeVisibleMethodSource(method, source) { + case VisibleMethodSourceOfficialAlipay: + return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceEasyPayAlipay: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceOfficialWechat: + return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay + case VisibleMethodSourceEasyPayWechat: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay + default: + return "", false + } +} + +func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer { + if inner == nil || configService == nil || configService.entClient == nil { + return inner + } + return &visibleMethodLoadBalancer{inner: inner, configService: configService} +} + +func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { + return lb.inner.GetInstanceConfig(ctx, instanceID) +} + +func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) { + visibleMethod := NormalizeVisibleMethod(paymentType) + if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) { + return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount) + } + + inst, err := lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod) + if err != nil { + return nil, err + } + if inst == nil { + return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod) + } + return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount) +} + +func visibleMethodEnabledSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipayEnabled + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpayEnabled + default: + return "" + } +} + +func visibleMethodSourceSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipaySource + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpaySource + default: + return "" + } +} + +func CanonicalizeReturnURL(raw string, srcHost string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + parsed, err := url.Parse(raw) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https") + } + parsed.Fragment = "" + if parsed.Path == "" { + parsed.Path = "/" + } + if parsed.Path != paymentResultReturnPath { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page") + } + if !sameOriginHost(parsed.Host, srcHost) { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site") + } + return parsed.String(), nil +} + +func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) { + canonical := strings.TrimSpace(base) + if canonical == "" { + return "", nil + } + + parsed, err := url.Parse(canonical) + if err != nil { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL") + } + if !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL") + } + parsed.Fragment = "" + + query := parsed.Query() + if orderID > 0 { + query.Set("order_id", strconv.FormatInt(orderID, 10)) + } + if strings.TrimSpace(resumeToken) != "" { + query.Set("resume_token", strings.TrimSpace(resumeToken)) + } + query.Set("status", "success") + parsed.RawQuery = query.Encode() + + return parsed.String(), nil +} + +func sameOriginHost(returnURLHost string, requestHost string) bool { + returnHost := strings.TrimSpace(returnURLHost) + reqHost := strings.TrimSpace(requestHost) + if returnHost == "" || reqHost == "" { + return false + } + if strings.EqualFold(returnHost, reqHost) { + return true + } + + returnName, returnPort := splitHostPortDefault(returnHost) + reqName, reqPort := splitHostPortDefault(reqHost) + if returnName == "" || reqName == "" { + return false + } + return strings.EqualFold(returnName, reqName) && returnPort == reqPort +} + +func splitHostPortDefault(raw string) (string, string) { + if host, port, err := net.SplitHostPort(raw); err == nil { + return host, port + } + return raw, "" +} + +func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } + if claims.OrderID <= 0 { + return "", fmt.Errorf("resume token requires order id") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + if claims.ExpiresAt == 0 { + claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix() + } + return s.createSignedToken(claims) +} + +func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } + var claims ResumeTokenClaims + if err := s.parseSignedToken(token, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") + } + if claims.OrderID <= 0 { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id") + } + if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil { + return nil, err + } + return &claims, nil +} + +func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } + claims.OpenID = strings.TrimSpace(claims.OpenID) + if claims.OpenID == "" { + return "", fmt.Errorf("wechat payment resume token requires openid") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + if claims.ExpiresAt == 0 { + claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix() + } + if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" { + claims.PaymentType = normalized + } + if claims.PaymentType == "" { + claims.PaymentType = payment.TypeWxpay + } + if claims.OrderType == "" { + claims.OrderType = payment.OrderTypeBalance + } + claims.TokenType = wechatPaymentResumeTokenType + return s.createSignedToken(claims) +} + +func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } + var claims WeChatPaymentResumeClaims + if err := s.parseSignedToken(token, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid") + } + if claims.TokenType != wechatPaymentResumeTokenType { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch") + } + claims.OpenID = strings.TrimSpace(claims.OpenID) + if claims.OpenID == "" { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid") + } + if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil { + return nil, err + } + if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" { + claims.PaymentType = normalized + } + if claims.PaymentType == "" { + claims.PaymentType = payment.TypeWxpay + } + if claims.OrderType == "" { + claims.OrderType = payment.OrderTypeBalance + } + return &claims, nil +} + +func (s *PaymentResumeService) createSignedToken(claims any) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal resume claims: %w", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + return encodedPayload + "." + s.sign(encodedPayload), nil +} + +func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { + parts := strings.Split(token, ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") + } + if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed") + } + return json.Unmarshal(payload, dest) +} + +func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { + if expiresAt <= 0 { + return nil + } + if time.Now().Unix() > expiresAt { + return infraerrors.BadRequest(code, message) + } + return nil +} + +func (s *PaymentResumeService) sign(payload string) string { + mac := hmac.New(sha256.New, s.signingKey) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..78b6bba338a99bbf8eb4aef7663a42ba0a87da7d --- /dev/null +++ b/backend/internal/service/payment_resume_service_test.go @@ -0,0 +1,420 @@ +//go:build unit + +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/url" + "strconv" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeVisibleMethods(t *testing.T) { + t.Parallel() + + got := NormalizeVisibleMethods([]string{ + "alipay_direct", + "alipay", + " wxpay_direct ", + "wxpay", + "stripe", + }) + + want := []string{"alipay", "wxpay", "stripe"} + if len(got) != len(want) { + t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestNormalizePaymentSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expect string + }{ + {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect}, + {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume}, + {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizePaymentSource(tt.input); got != tt.expect { + t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestCanonicalizeReturnURL(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com") + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://example.com/payment/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2") + } +} + +func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject relative URLs") + } +} + +func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject external hosts") + } +} + +func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths") + } +} + +func TestBuildPaymentReturnURL(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + + parsed, err := url.Parse(got) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + if parsed.Fragment != "" { + t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment) + } + query := parsed.Query() + if query.Get("from") != "checkout" { + t.Fatalf("expected original query to be preserved, got %q", query.Get("from")) + } + if query.Get("order_id") != strconv.FormatInt(42, 10) { + t.Fatalf("order_id = %q", query.Get("order_id")) + } + if query.Get("resume_token") != "resume-token" { + t.Fatalf("resume_token = %q", query.Get("resume_token")) + } + if query.Get("status") != "success" { + t.Fatalf("status = %q", query.Get("status")) + } +} + +func TestBuildPaymentReturnURLEmptyBase(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("", 42, "resume-token") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + if got != "" { + t.Fatalf("buildPaymentReturnURL = %q, want empty string", got) + } +} + +func TestPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + ProviderInstanceID: "19", + ProviderKey: "easypay", + PaymentType: "wxpay", + CanonicalReturnURL: "https://example.com/payment/result", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + claims, err := svc.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken returned error: %v", err) + } + if claims.OrderID != 42 || claims.UserID != 7 { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" { + t.Fatalf("claims provider snapshot mismatch: %+v", claims) + } + if claims.CanonicalReturnURL != "https://example.com/payment/result" { + t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL) + } +} + +func TestCreateTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42}) + if err == nil { + t.Fatal("CreateToken should reject missing signing key") + } +} + +func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7}) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseToken(token) + if err == nil { + t.Fatal("ParseToken should reject tokens when signing key is missing") + } +} + +func TestParseTokenRejectsExpiredToken(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + IssuedAt: time.Now().Add(-25 * time.Hour).Unix(), + ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + _, err = svc.ParseToken(token) + if err == nil { + t.Fatal("ParseToken should reject expired tokens") + } +} + +func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeSubscription, + PlanID: 7, + RedirectTo: "/purchase?from=wechat", + Scope: "snsapi_base", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 { + t.Fatalf("claims payment context mismatch: %+v", claims) + } + if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" { + t.Fatalf("claims redirect/scope mismatch: %+v", claims) + } +} + +func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"}) + if err == nil { + t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key") + } +} + +func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{ + TokenType: wechatPaymentResumeTokenType, + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + }) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseWeChatPaymentResumeToken(token) + if err == nil { + t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing") + } +} + +func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + IssuedAt: time.Now().Add(-30 * time.Minute).Unix(), + ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(), + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + _, err = svc.ParseWeChatPaymentResumeToken(token) + if err == nil { + t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens") + } +} + +func TestNormalizeVisibleMethodSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + input string + want string + }{ + {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay}, + {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay}, + {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat}, + {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat}, + {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want { + t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want) + } + }) + } +} + +func TestVisibleMethodProviderKeyForSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + source string + want string + ok bool + }{ + {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true}, + {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true}, + {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true}, + {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true}, + {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source) + if got != tt.want || ok != tt.ok { + t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create alipay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != payment.TypeAlipay { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: newPaymentConfigServiceTestClient(t), + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil { + t.Fatal("SelectInstance should reject when no enabled provider instance exists") + } +} + +type captureLoadBalancer struct { + lastProviderKey string + lastPaymentType string +} + +func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) { + c.lastProviderKey = providerKey + c.lastPaymentType = paymentType + return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil +} + +func mustCreateFallbackSignedToken(t *testing.T, claims any) string { + t.Helper() + + payload, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + mac := hmac.New(sha256.New, []byte("sub2api-payment-resume")) + _, _ = mac.Write([]byte(encodedPayload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return encodedPayload + "." + signature +} diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 6fc23f974f731731b91356c2bd756f4cbd396aca..73bbb2566ef919779413bf4896eaf563169e1dcb 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -9,7 +9,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -65,29 +64,39 @@ func generateRandomString(n int) string { } type CreateOrderRequest struct { - UserID int64 - Amount float64 - PaymentType string - ClientIP string - IsMobile bool - SrcHost string - SrcURL string - OrderType string - PlanID int64 + UserID int64 + Amount float64 + PaymentType string + OpenID string + ClientIP string + IsMobile bool + IsWeChatBrowser bool + SrcHost string + SrcURL string + ReturnURL string + PaymentSource string + OrderType string + PlanID int64 } type CreateOrderResponse struct { - OrderID int64 `json:"order_id"` - Amount float64 `json:"amount"` - PayAmount float64 `json:"pay_amount"` - FeeRate float64 `json:"fee_rate"` - Status string `json:"status"` - PaymentType string `json:"payment_type"` - PayURL string `json:"pay_url,omitempty"` - QRCode string `json:"qr_code,omitempty"` - ClientSecret string `json:"client_secret,omitempty"` - ExpiresAt time.Time `json:"expires_at"` - PaymentMode string `json:"payment_mode,omitempty"` + OrderID int64 `json:"order_id"` + Amount float64 `json:"amount"` + PayAmount float64 `json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + Status string `json:"status"` + ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"` + PaymentType string `json:"payment_type"` + OutTradeNo string `json:"out_trade_no,omitempty"` + PayURL string `json:"pay_url,omitempty"` + QRCode string `json:"qr_code,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"` + JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"` + JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + PaymentMode string `json:"payment_mode,omitempty"` + ResumeToken string `json:"resume_token,omitempty"` } type OrderListParams struct { @@ -165,10 +174,13 @@ type PaymentService struct { configService *PaymentConfigService userRepo UserRepository groupRepo GroupRepository + resumeService *PaymentResumeService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + return svc } // --- Provider Registry --- @@ -219,25 +231,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) { } } -// GetWebhookProvider returns the provider instance that should verify a webhook. -// It extracts out_trade_no from the raw body, looks up the order to find the -// original provider instance, and creates a provider with that instance's credentials. -// Falls back to the registry provider when the order cannot be found. -func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { - if outTradeNo != "" { - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) - if err == nil { - p, pErr := s.getOrderProvider(ctx, order) - if pErr == nil { - return p, nil - } - slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr) - } - } - s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) -} - // --- Helpers --- func psIsRefundStatus(s string) bool { @@ -262,6 +255,20 @@ func psNilIfEmpty(s string) *string { return &s } +func (s *PaymentService) paymentResume() *PaymentResumeService { + if s.resumeService != nil { + return s.resumeService + } + return NewPaymentResumeService(psResumeSigningKey(s.configService)) +} + +func psResumeSigningKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go new file mode 100644 index 0000000000000000000000000000000000000000..477e8e8b22f899eb11a7ceeaad2b8d7289b522a7 --- /dev/null +++ b/backend/internal/service/payment_visible_method_instances.go @@ -0,0 +1,166 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string { + methodSet := make(map[string]struct{}, 2) + addMethod := func(method string) { + method = NormalizeVisibleMethod(method) + switch method { + case payment.TypeAlipay, payment.TypeWxpay: + methodSet[method] = struct{}{} + } + } + + switch strings.TrimSpace(providerKey) { + case payment.TypeAlipay: + if strings.TrimSpace(supportedTypes) == "" { + addMethod(payment.TypeAlipay) + break + } + for _, supportedType := range splitTypes(supportedTypes) { + if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay { + addMethod(payment.TypeAlipay) + break + } + } + case payment.TypeWxpay: + if strings.TrimSpace(supportedTypes) == "" { + addMethod(payment.TypeWxpay) + break + } + for _, supportedType := range splitTypes(supportedTypes) { + if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay { + addMethod(payment.TypeWxpay) + break + } + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(supportedTypes) { + addMethod(supportedType) + } + } + + methods := make([]string, 0, len(methodSet)) + for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if _, ok := methodSet[method]; ok { + methods = append(methods, method) + } + } + return methods +} + +func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool { + if inst == nil || !inst.Enabled { + return false + } + method = NormalizeVisibleMethod(method) + for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) { + if candidate == method { + return true + } + } + return false +} + +func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance { + filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances)) + for _, inst := range instances { + if providerSupportsVisibleMethod(inst, method) { + filtered = append(filtered, inst) + } + } + return filtered +} + +func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error { + metadata := map[string]string{ + "payment_method": NormalizeVisibleMethod(method), + } + if conflicting != nil { + metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID) + metadata["conflicting_provider_key"] = conflicting.ProviderKey + metadata["conflicting_provider_name"] = conflicting.Name + } + return infraerrors.Conflict( + "PAYMENT_PROVIDER_CONFLICT", + fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)), + ).WithMetadata(metadata) +} + +func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( + ctx context.Context, + excludeID int64, + providerKey string, + supportedTypes string, + enabled bool, +) error { + if s == nil || s.entClient == nil || !enabled { + return nil + } + + claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes) + if len(claimedMethods) == 0 { + return nil + } + + query := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)) + if excludeID > 0 { + query = query.Where(paymentproviderinstance.IDNEQ(excludeID)) + } + instances, err := query.All(ctx) + if err != nil { + return fmt.Errorf("query enabled payment providers: %w", err) + } + + for _, method := range claimedMethods { + for _, inst := range instances { + if providerSupportsVisibleMethod(inst, method) { + return buildPaymentProviderConflictError(method, inst) + } + } + } + return nil +} + +func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( + ctx context.Context, + method string, +) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil { + return nil, nil + } + + method = NormalizeVisibleMethod(method) + if method != payment.TypeAlipay && method != payment.TypeWxpay { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + Order(paymentproviderinstance.BySortOrder()). + All(ctx) + if err != nil { + return nil, fmt.Errorf("query enabled payment providers: %w", err) + } + + matching := filterEnabledVisibleMethodInstances(instances, method) + switch len(matching) { + case 0: + return nil, nil + case 1: + return matching[0], nil + default: + return nil, buildPaymentProviderConflictError(method, matching[0]) + } +} diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go new file mode 100644 index 0000000000000000000000000000000000000000..f2da40d9b4c6dfda2314347294b25c3f1f27bbea --- /dev/null +++ b/backend/internal/service/payment_webhook_provider.go @@ -0,0 +1,148 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +// GetWebhookProvider returns the provider instance that should verify a webhook. +// It resolves the original provider instance from the order whenever possible and +// only falls back to a registry provider for legacy/single-instance scenarios. +func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo) + if err != nil { + return nil, err + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers[0], nil +} + +// GetWebhookProviders returns provider candidates that can verify the webhook. +// Official WeChat Pay may require multiple candidates because the callback body +// cannot be bound to a merchant before decryption. +func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) { + if outTradeNo != "" { + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) + if err == nil { + if psHasPinnedProviderInstance(order) { + prov, err := s.getPinnedOrderProvider(ctx, order) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + inst, err := s.getOrderProviderInstance(ctx, order) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + prov, err := s.createProviderFromInstance(ctx, inst) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + s.EnsureProviders(ctx) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + } + + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) + } + + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + + s.EnsureProviders(ctx) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil +} + +func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst == nil { + return nil, fmt.Errorf("order %d provider instance is missing", o.ID) + } + return s.createProviderFromInstance(ctx, inst) +} + +func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" || s == nil || s.entClient == nil { + return false + } + + count, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Count(ctx) + if err != nil { + slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err) + return false + } + return count <= 1 +} + +func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { + return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "")) +} + +func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) { + providerKey = strings.TrimSpace(providerKey) + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("query webhook provider instances: %w", err) + } + if len(instances) == 0 { + return nil, payment.ErrProviderNotFound + } + + providers := make([]payment.Provider, 0, len(instances)) + for _, inst := range instances { + prov, provErr := s.createProviderFromInstance(ctx, inst) + if provErr != nil { + slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr) + continue + } + providers = append(providers, prov) + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers, nil +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f3efa1f49784ef2989e76b168802156b6365404 --- /dev/null +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -0,0 +1,510 @@ +//go:build unit + +package service + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "strconv" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef" + +type webhookProviderTestDouble struct { + key string + types []payment.PaymentType +} + +func (p webhookProviderTestDouble) Name() string { return p.key } +func (p webhookProviderTestDouble) ProviderKey() string { return p.key } +func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types } +func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + require.NoError(t, err) + + encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey)) + require.NoError(t, err) + return encrypted +} + +func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer { + return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) +} + +func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-" + suffix, + "mchId": "mch-" + suffix, + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-" + suffix, + "certSerial": "cert-serial-" + suffix, + }) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpayDirect, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("easypay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-disabled-legacy"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(false). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-enabled-current"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-only"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeWxpay + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipayDirect, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-snapshot"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 42, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(inst.ID, 10), + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-legacy-fallback"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 43, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "999999", + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.Nil(t, got) + require.Error(t, err) + require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing") +} + +func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b") + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig(wxpayConfigA). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-b"). + SetConfig(wxpayConfigB). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "") + require.NoError(t, err) + require.Len(t, providers, 2) +} + +func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-a"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-b"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + registry: payment.NewRegistry(), + providersLoaded: true, + } + + _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "") + require.Error(t, err) + require.Contains(t, err.Error(), "ambiguous") +} + +func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeStripe, + types: []payment.PaymentType{payment.TypeStripe}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "") + require.NoError(t, err) + require.Len(t, providers, 1) + prov := providers[0] + require.Equal(t, payment.TypeStripe, prov.ProviderKey()) +} + +func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("webhook"). + Save(ctx) + require.NoError(t, err) + + pinnedInstanceID := "999" + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("TEST-RECHARGE"). + SetOutTradeNo("sub2_test_pinned_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(pinnedInstanceID). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeWxpay, + types: []payment.PaymentType{payment.TypeWxpay}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + require.Error(t, err) + require.Contains(t, err.Error(), "provider instance") +} + +func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("snapshot-webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("snapshot-webhook"). + Save(ctx) + require.NoError(t, err) + + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b") + instA, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-a"). + SetConfig(wxpayConfigA). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-b"). + SetConfig(wxpayConfigB). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(66). + SetPayAmount(66). + SetFeeRate(0). + SetRechargeCode("SNAPSHOT-WEBHOOK"). + SetOutTradeNo("sub2_test_snapshot_webhook_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderSnapshot(map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(instA.ID, 10), + "provider_key": payment.TypeWxpay, + "payment_mode": "native", + }). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order") + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey()) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7f4a2eb1319e0e5b5b45c255a3412f0403370271..fe566feccbbdb2f2eefa851e40395bb2486972fe 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -114,6 +114,149 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder } +type ProviderDefaultGrantSettings struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + GrantOnSignup bool + GrantOnFirstBind bool +} + +type AuthSourceDefaultSettings struct { + Email ProviderDefaultGrantSettings + LinuxDo ProviderDefaultGrantSettings + OIDC ProviderDefaultGrantSettings + WeChat ProviderDefaultGrantSettings + ForceEmailOnThirdPartySignup bool +} + +type authSourceDefaultKeySet struct { + balance string + concurrency string + subscriptions string + grantOnSignup string + grantOnFirstBind string +} + +var ( + emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultEmailBalance, + concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, + subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + } + linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultLinuxDoBalance, + concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, + subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + } + oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultOIDCBalance, + concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, + subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + } + weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultWeChatBalance, + concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, + subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + } +) + +const ( + defaultAuthSourceBalance = 0 + defaultAuthSourceConcurrency = 5 + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontend = "/auth/wechat/callback" +) + +func normalizeWeChatConnectModeSetting(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + case "mobile": + return "mobile" + default: + return "open" + } +} + +func defaultWeChatConnectScopeForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return "snsapi_userinfo" + case "mobile": + return "" + } + return defaultWeChatConnectScopes +} + +func normalizeWeChatConnectScopeSetting(raw, mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopeForMode(mode) + } + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(mode) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + return openEnabled, mpEnabled, mobileEnabled + } + + if !enabled { + return false, false, false + } + if mode == "mp" { + return false, true, false + } + if mode == "mobile" { + return false, false, true + } + return true, false, false +} + +func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + switch { + case mpEnabled: + return "mp" + case mobileEnabled: + return "mobile" + case openEnabled: + return "open" + default: + return normalizeWeChatConnectModeSetting(mode) + } +} + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -156,6 +299,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyForceEmailOnThirdPartySignup, SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, @@ -178,6 +322,22 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectOpenAppID, + SettingKeyWeChatConnectOpenAppSecret, + SettingKeyWeChatConnectMPAppID, + SettingKeyWeChatConnectMPAppSecret, + SettingKeyWeChatConnectMobileAppID, + SettingKeyWeChatConnectMobileAppSecret, + SettingKeyWeChatConnectOpenEnabled, + SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, SettingKeyBackendModeEnabled, SettingPaymentEnabled, SettingKeyOIDCConnectEnabled, @@ -212,6 +372,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -232,6 +393,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, @@ -254,6 +416,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + WeChatOAuthEnabled: weChatEnabled, + WeChatOAuthOpenEnabled: weChatOpenEnabled, + WeChatOAuthMPEnabled: weChatMPEnabled, + WeChatOAuthMobileEnabled: weChatMobileEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, @@ -310,6 +476,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` @@ -344,6 +514,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, + WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, + WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled, + WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, @@ -356,6 +530,110 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any }, nil } +func DefaultWeChatConnectScopesForMode(mode string) string { + return defaultWeChatConnectScopeForMode(mode) +} + +func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { + enabled := settings[SettingKeyWeChatConnectEnabled] == "true" + mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) + openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, enabled, mode) + mode = normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, mode) + + cfg := WeChatConnectOAuthConfig{ + Enabled: enabled, + LegacyAppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]), + LegacyAppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]), + OpenAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], settings[SettingKeyWeChatConnectAppID])), + OpenAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], settings[SettingKeyWeChatConnectAppSecret])), + MPAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], settings[SettingKeyWeChatConnectAppID])), + MPAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], settings[SettingKeyWeChatConnectAppSecret])), + MobileAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], settings[SettingKeyWeChatConnectAppID])), + MobileAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], settings[SettingKeyWeChatConnectAppSecret])), + OpenEnabled: openEnabled, + MPEnabled: mpEnabled, + MobileEnabled: mobileEnabled, + Mode: mode, + Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], mode), + RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]), + FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]), + } + if cfg.FrontendRedirectURL == "" { + cfg.FrontendRedirectURL = defaultWeChatConnectFrontend + } + + if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) { + return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if cfg.OpenEnabled { + if cfg.AppIDForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured") + } + if cfg.AppSecretForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured") + } + } + if cfg.MPEnabled { + if cfg.AppIDForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured") + } + if cfg.AppSecretForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured") + } + } + if cfg.MobileEnabled { + if cfg.AppIDForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured") + } + if cfg.AppSecretForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured") + } + } + if cfg.RedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + if cfg.FrontendRedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured") + } + if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") + } + return cfg, nil +} + +func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) { + if settings[SettingKeyWeChatConnectEnabled] != "true" { + return false, false, false, false + } + + mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) + openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, true, mode) + redirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) + frontendRedirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) + if frontendRedirectURL == "" { + frontendRedirectURL = defaultWeChatConnectFrontend + } + + legacyAppID := strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) + legacyAppSecret := strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) + openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], legacyAppID)) + openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], legacyAppSecret)) + mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], legacyAppID)) + mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], legacyAppSecret)) + mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], legacyAppID)) + mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], legacyAppSecret)) + + webRedirectReady := redirectURL != "" && frontendRedirectURL != "" + openReady := openEnabled && webRedirectReady && openAppID != "" && openAppSecret != "" + mpReady := mpEnabled && webRedirectReady && mpAppID != "" && mpAppSecret != "" + mobileReady := mobileEnabled && mobileAppID != "" && mobileAppSecret != "" + + return openReady || mpReady || mobileReady, openReady, mpReady, mobileReady +} + // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -480,17 +758,82 @@ func parseCustomMenuItemURLs(raw string) []string { // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { - if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { + return err + } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write. +func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { + return err + } + + authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults) + if err != nil { return err } + for key, value := range authSourceUpdates { + updates[key] = value + } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return nil, err + } normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) } if normalizedWhitelist == nil { normalizedWhitelist = []string{} } settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist + alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) + if err != nil { + return nil, err + } + wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) + if err != nil { + return nil, err + } + settings.PaymentVisibleMethodAlipaySource = alipaySource + settings.PaymentVisibleMethodWxpaySource = wxpaySource + settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) + settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) + settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMode = normalizeWeChatConnectStoredMode( + settings.WeChatConnectOpenEnabled, + settings.WeChatConnectMPEnabled, + settings.WeChatConnectMobileEnabled, + settings.WeChatConnectMode, + ) + settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) + settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) + settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) + if settings.WeChatConnectFrontendRedirectURL == "" { + settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } updates := make(map[string]string) @@ -499,7 +842,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err) } updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) @@ -560,6 +903,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret } + // WeChat Connect OAuth 登录 + updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) + updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID + updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID + updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID + updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID + updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled) + updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled) + updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled) + updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode + updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes + updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL + updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL + if settings.WeChatConnectAppSecret != "" { + updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret + } + if settings.WeChatConnectOpenAppSecret != "" { + updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret + } + if settings.WeChatConnectMPAppSecret != "" { + updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret + } + if settings.WeChatConnectMobileAppSecret != "" { + updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -578,7 +947,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) if err != nil { - return fmt.Errorf("marshal table page size options: %w", err) + return nil, fmt.Errorf("marshal table page size options: %w", err) } updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems @@ -589,7 +958,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { - return fmt.Errorf("marshal default subscriptions: %w", err) + return nil, fmt.Errorf("marshal default subscriptions: %w", err) } updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) @@ -626,6 +995,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource + updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource + updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) + updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) + updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) @@ -634,32 +1008,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - versionBoundsSF.Forget("version_bounds") - versionBoundsCache.Store(&cachedVersionBounds{ - min: settings.MinClaudeCodeVersion, - max: settings.MaxClaudeCodeVersion, - expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), - }) - backendModeSF.Forget("backend_mode") - backendModeCache.Store(&cachedBackendMode{ - value: settings.BackendModeEnabled, - expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), - }) - gatewayForwardingSF.Forget("gateway_forwarding") - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), - }) - if s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update + return updates, nil +} + +func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { + if settings == nil { + return nil, nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return nil, err } } - return err + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + return updates, nil +} + +func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { + if settings == nil { + return + } + + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 + versionBoundsSF.Forget("version_bounds") + versionBoundsCache.Store(&cachedVersionBounds{ + min: settings.MinClaudeCodeVersion, + max: settings.MaxClaudeCodeVersion, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), + }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + gatewayForwardingSF.Forget("gateway_forwarding") + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } } func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { @@ -919,6 +1327,88 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS return parseDefaultSubscriptions(value) } +func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) { + keys := []string{ + SettingKeyAuthSourceDefaultEmailBalance, + SettingKeyAuthSourceDefaultEmailConcurrency, + SettingKeyAuthSourceDefaultEmailSubscriptions, + SettingKeyAuthSourceDefaultEmailGrantOnSignup, + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + SettingKeyAuthSourceDefaultLinuxDoBalance, + SettingKeyAuthSourceDefaultLinuxDoConcurrency, + SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + SettingKeyAuthSourceDefaultOIDCBalance, + SettingKeyAuthSourceDefaultOIDCConcurrency, + SettingKeyAuthSourceDefaultOIDCSubscriptions, + SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + SettingKeyAuthSourceDefaultWeChatBalance, + SettingKeyAuthSourceDefaultWeChatConcurrency, + SettingKeyAuthSourceDefaultWeChatSubscriptions, + SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyForceEmailOnThirdPartySignup, + } + + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get auth source default settings: %w", err) + } + + return &AuthSourceDefaultSettings{ + Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys), + LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), + OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), + WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", + }, nil +} + +func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) { + result := ProviderDefaultGrantSettings{ + Balance: s.GetDefaultBalance(ctx), + Concurrency: s.GetDefaultConcurrency(ctx), + Subscriptions: s.GetDefaultSubscriptions(ctx), + } + + defaults, err := s.GetAuthSourceDefaultSettings(ctx) + if err != nil { + return result, false, err + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok { + return result, false, nil + } + + enabled := providerDefaults.GrantOnSignup + if firstBind { + enabled = providerDefaults.GrantOnFirstBind + } + if !enabled { + return result, false, nil + } + + return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil +} + +func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { + updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings) + if err != nil { + return err + } + if len(updates) == 0 { + return nil + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return fmt.Errorf("update auth source default settings: %w", err) + } + return nil +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -933,25 +1423,59 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectOpenAppID: "", + SettingKeyWeChatConnectOpenAppSecret: "", + SettingKeyWeChatConnectMPAppID: "", + SettingKeyWeChatConnectMPAppSecret: "", + SettingKeyWeChatConnectMobileAppID: "", + SettingKeyWeChatConnectMobileAppSecret: "", + SettingKeyWeChatConnectOpenEnabled: "false", + SettingKeyWeChatConnectMPEnabled: "false", + SettingKeyWeChatConnectMobileEnabled: "false", + SettingKeyWeChatConnectMode: "open", + SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -973,7 +1497,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -1164,6 +1693,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken } + result.OIDCConnectUsePKCE = true + result.OIDCConnectValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) } else { @@ -1208,6 +1739,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + // WeChat Connect 设置:完全以 DB 系统设置为准。 + result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true" + result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) + result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) + result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != "" + result.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], result.WeChatConnectAppID)) + result.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], result.WeChatConnectAppSecret)) + result.WeChatConnectOpenAppSecretConfigured = result.WeChatConnectOpenAppSecret != "" + result.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], result.WeChatConnectAppID)) + result.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], result.WeChatConnectAppSecret)) + result.WeChatConnectMPAppSecretConfigured = result.WeChatConnectMPAppSecret != "" + result.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], result.WeChatConnectAppID)) + result.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], result.WeChatConnectAppSecret)) + result.WeChatConnectMobileAppSecretConfigured = result.WeChatConnectMobileAppSecret != "" + result.WeChatConnectOpenEnabled, result.WeChatConnectMPEnabled, result.WeChatConnectMobileEnabled = parseWeChatConnectCapabilitySettings( + settings, + result.WeChatConnectEnabled, + settings[SettingKeyWeChatConnectMode], + ) + result.WeChatConnectMode = normalizeWeChatConnectStoredMode( + result.WeChatConnectOpenEnabled, + result.WeChatConnectMPEnabled, + result.WeChatConnectMobileEnabled, + settings[SettingKeyWeChatConnectMode], + ) + result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], result.WeChatConnectMode) + result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) + result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) + if result.WeChatConnectFrontendRedirectURL == "" { + result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -1263,6 +1826,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0 } } + result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource]) + result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource]) + result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true" + result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" + result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" @@ -1292,6 +1860,28 @@ func isFalseSettingValue(value string) bool { } } +func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) { + source = strings.TrimSpace(source) + if source == "" { + if enabled { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source is required when the visible method is enabled", method), + ) + } + return "", nil + } + + normalized := NormalizeVisibleMethodSource(method, source) + if normalized == "" { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return normalized, nil +} + func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { raw = strings.TrimSpace(raw) if raw == "" { @@ -1317,6 +1907,73 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: false, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + +func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: globalDefaults.Balance, + Concurrency: globalDefaults.Concurrency, + Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...), + GrantOnSignup: providerDefaults.GrantOnSignup, + GrantOnFirstBind: providerDefaults.GrantOnFirstBind, + } + + if providerDefaults.Balance != defaultAuthSourceBalance { + result.Balance = providerDefaults.Balance + } + if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { + result.Concurrency = providerDefaults.Concurrency + } + if len(providerDefaults.Subscriptions) > 0 { + result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...) + } + + return result +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { @@ -1539,6 +2196,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } + effective.UsePKCE = true if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") @@ -1587,9 +2245,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } @@ -1597,6 +2252,35 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。 +// +// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。 +func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) { + keys := []string{ + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectOpenAppID, + SettingKeyWeChatConnectOpenAppSecret, + SettingKeyWeChatConnectMPAppID, + SettingKeyWeChatConnectMPAppSecret, + SettingKeyWeChatConnectMobileAppID, + SettingKeyWeChatConnectMobileAppSecret, + SettingKeyWeChatConnectOpenEnabled, + SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err) + } + return s.parseWeChatConnectOAuthConfig(settings) +} + // GetOverloadCooldownSettings 获取529过载冷却配置 func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) @@ -1737,6 +2421,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" } + effective.UsePKCE = true + effective.ValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) } @@ -1864,9 +2550,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff4974066c5c18589235e20194cbcb39f9a3153 --- /dev/null +++ b/backend/internal/service/setting_service_auth_source_defaults_test.go @@ -0,0 +1,138 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type authSourceDefaultsRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for key, value := range settings { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) { + repo := &authSourceDefaultsRepoStub{ + values: map[string]string{ + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true", + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetAuthSourceDefaultSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 12.5, got.Email.Balance) + require.Equal(t, 7, got.Email.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions) + require.False(t, got.Email.GrantOnSignup) + require.False(t, got.Email.GrantOnFirstBind) + require.Equal(t, 0.0, got.LinuxDo.Balance) + require.Equal(t, 5, got.LinuxDo.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions) + require.False(t, got.LinuxDo.GrantOnSignup) + require.True(t, got.LinuxDo.GrantOnFirstBind) + require.Equal(t, 5, got.OIDC.Concurrency) + require.Equal(t, 5, got.WeChat.Concurrency) + require.False(t, got.OIDC.GrantOnSignup) + require.False(t, got.WeChat.GrantOnSignup) + require.True(t, got.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) { + repo := &authSourceDefaultsRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + Balance: 1.25, + Concurrency: 3, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + LinuxDo: ProviderDefaultGrantSettings{ + Balance: 2, + Concurrency: 4, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}}, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + OIDC: ProviderDefaultGrantSettings{ + Balance: 3, + Concurrency: 5, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}}, + GrantOnSignup: true, + GrantOnFirstBind: true, + }, + WeChat: ProviderDefaultGrantSettings{ + Balance: 4, + Concurrency: 6, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, + GrantOnSignup: false, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: true, + }) + require.NoError(t, err) + require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup]) + require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind]) + require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup]) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got)) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got) +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 5cf1e860eeab69786a119d9745b95b6cff59c214..497d1e36059452f78f7203732b37a8d02f0d1d96 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -77,3 +77,38 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) require.Equal(t, 50, settings.TableDefaultPageSize) require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) } + +func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectOpenEnabled: "true", + SettingKeyWeChatConnectMPEnabled: "true", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.WeChatOAuthEnabled) + require.True(t, settings.WeChatOAuthOpenEnabled) + require.True(t, settings.WeChatOAuthMPEnabled) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index e62218b4543d22487078bd2119198ece8bd7e5cc..9dc0ca59a3ff1d1b5bf99e9905bed1b576e35826 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize]) require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions]) } + +func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "alipay", + PaymentVisibleMethodWxpaySource: "easypay", + PaymentVisibleMethodAlipayEnabled: true, + PaymentVisibleMethodWxpayEnabled: false, + OpenAIAdvancedSchedulerEnabled: true, + }) + require.NoError(t, err) + require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey]) +} + +func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "not-a-provider", + }) + require.Error(t, err) + require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..73d86e8fa0ffb7b75faa12debe98cd6b1a18d563 --- /dev/null +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -0,0 +1,81 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingWeChatRepoStub struct { + values map[string]string +} + +func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) { + if value, ok := s.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (s *settingWeChatRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingWeChatRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-db-app", + SettingKeyWeChatConnectAppSecret: "wx-db-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectOpenEnabled: "true", + SettingKeyWeChatConnectMPEnabled: "true", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.Equal(t, "wx-db-app", got.AppIDForMode("mp")) + require.Equal(t, "wx-db-secret", got.AppSecretForMode("mp")) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "mp", got.Mode) + require.Equal(t, "snsapi_base", got.Scopes) + require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL) + require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ab2eb274fd95fbcb66291c6a1d3be1c13670a2a1..d2ef8faee0f1fb8ad7d3b9f228c5201c4f2ef3f3 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,5 +1,16 @@ package service +import "strings" + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + type SystemSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool @@ -31,6 +42,28 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool + WeChatConnectAppID string + WeChatConnectAppSecret string + WeChatConnectAppSecretConfigured bool + WeChatConnectOpenAppID string + WeChatConnectOpenAppSecret string + WeChatConnectOpenAppSecretConfigured bool + WeChatConnectMPAppID string + WeChatConnectMPAppSecret string + WeChatConnectMPAppSecretConfigured bool + WeChatConnectMobileAppID string + WeChatConnectMobileAppSecret string + WeChatConnectMobileAppSecretConfigured bool + WeChatConnectOpenEnabled bool + WeChatConnectMPEnabled bool + WeChatConnectMobileEnabled bool + WeChatConnectMode string + WeChatConnectScopes string + WeChatConnectRedirectURL string + WeChatConnectFrontendRedirectURL string + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool OIDCConnectProviderName string @@ -110,6 +143,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Payment visible method routing + PaymentVisibleMethodAlipaySource string + PaymentVisibleMethodWxpaySource string + PaymentVisibleMethodAlipayEnabled bool + PaymentVisibleMethodWxpayEnabled bool + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 @@ -128,6 +170,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool + ForceEmailOnThirdPartySignup bool RegistrationEmailSuffixWhitelist []string PromoCodeEnabled bool PasswordResetEnabled bool @@ -151,12 +194,16 @@ type PublicSettings struct { CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints - LinuxDoOAuthEnabled bool - BackendModeEnabled bool - PaymentEnabled bool - OIDCOAuthEnabled bool - OIDCOAuthProviderName string - Version string + LinuxDoOAuthEnabled bool + WeChatOAuthEnabled bool + WeChatOAuthOpenEnabled bool + WeChatOAuthMPEnabled bool + WeChatOAuthMobileEnabled bool + BackendModeEnabled bool + PaymentEnabled bool + OIDCOAuthEnabled bool + OIDCOAuthProviderName string + Version string BalanceLowNotifyEnabled bool AccountQuotaNotifyEnabled bool @@ -164,6 +211,66 @@ type PublicSettings struct { BalanceLowNotifyRechargeURL string } +type WeChatConnectOAuthConfig struct { + Enabled bool + LegacyAppID string + LegacyAppSecret string + OpenAppID string + OpenAppSecret string + MPAppID string + MPAppSecret string + MobileAppID string + MobileAppSecret string + OpenEnabled bool + MPEnabled bool + MobileEnabled bool + Mode string + Scopes string + RedirectURL string + FrontendRedirectURL string +} + +func (cfg WeChatConnectOAuthConfig) SupportsMode(mode string) bool { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return cfg.MPEnabled + case "mobile": + return cfg.MobileEnabled + default: + return cfg.OpenEnabled + } +} + +func (cfg WeChatConnectOAuthConfig) ScopeForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return normalizeWeChatConnectScopeSetting(cfg.Scopes, "mp") + case "mobile": + return "" + } + return defaultWeChatConnectScopeForMode("open") +} + +func (cfg WeChatConnectOAuthConfig) AppIDForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return strings.TrimSpace(firstNonEmpty(cfg.MPAppID, cfg.LegacyAppID)) + case "mobile": + return strings.TrimSpace(firstNonEmpty(cfg.MobileAppID, cfg.LegacyAppID)) + } + return strings.TrimSpace(firstNonEmpty(cfg.OpenAppID, cfg.LegacyAppID)) +} + +func (cfg WeChatConnectOAuthConfig) AppSecretForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return strings.TrimSpace(firstNonEmpty(cfg.MPAppSecret, cfg.LegacyAppSecret)) + case "mobile": + return strings.TrimSpace(firstNonEmpty(cfg.MobileAppSecret, cfg.LegacyAppSecret)) + } + return strings.TrimSpace(firstNonEmpty(cfg.OpenAppSecret, cfg.LegacyAppSecret)) +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sql_errors.go b/backend/internal/service/sql_errors.go new file mode 100644 index 0000000000000000000000000000000000000000..7c0155a4e68044742ddb1084f570f9dde33484bc --- /dev/null +++ b/backend/internal/service/sql_errors.go @@ -0,0 +1,14 @@ +package service + +import ( + "database/sql" + "errors" + "strings" +) + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") +} diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go index 5192fe3d2763ae26d043079f99e5ff0c72e85470..052739ed19c82215479e2bcb656521ca7e394226 100644 --- a/backend/internal/service/totp_service.go +++ b/backend/internal/service/totp_service.go @@ -58,9 +58,15 @@ type TotpSetupSession struct { // TotpLoginSession represents a pending 2FA login session type TotpLoginSession struct { - UserID int64 - Email string - TokenExpiry time.Time + UserID int64 + Email string + TokenExpiry time.Time + PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"` +} + +type PendingOAuthBindLoginSession struct { + PendingSessionToken string `json:"pending_session_token,omitempty"` + BrowserSessionKey string `json:"browser_session_key,omitempty"` } // TotpStatus represents the TOTP status for a user @@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) // CreateLoginSession creates a temporary login session for 2FA func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) { + return s.createLoginSession(ctx, userID, email, nil) +} + +// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will +// finalize a pending OAuth bind after the TOTP code is verified. +func (s *TotpService) CreatePendingOAuthBindLoginSession( + ctx context.Context, + userID int64, + email string, + pendingSessionToken string, + browserSessionKey string, +) (string, error) { + return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{ + PendingSessionToken: pendingSessionToken, + BrowserSessionKey: browserSessionKey, + }) +} + +func (s *TotpService) createLoginSession( + ctx context.Context, + userID int64, + email string, + pendingOAuthBind *PendingOAuthBindLoginSession, +) (string, error) { // Generate a random temp token tempToken, err := generateRandomToken(32) if err != nil { @@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai } session := &TotpLoginSession{ - UserID: userID, - Email: email, - TokenExpiry: time.Now().Add(totpLoginTTL), + UserID: userID, + Email: email, + TokenExpiry: time.Now().Add(totpLoginTTL), + PendingOAuthBind: pendingOAuthBind, } if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil { diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 59f8aa6b78eb4433a877f6b5490c0b51d28de9f4..fa04d95e84d2e324b0df133f805e68c7f624b82b 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -7,19 +7,28 @@ import ( ) type User struct { - ID int64 - Email string - Username string - Notes string - PasswordHash string - Role string - Balance float64 - Concurrency int - Status string - AllowedGroups []int64 - TokenVersion int64 // Incremented on password change to invalidate existing tokens - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Email string + Username string + Notes string + AvatarURL string + AvatarSource string + AvatarMIME string + AvatarByteSize int + AvatarSHA256 string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + TokenVersion int64 // Incremented on password change to invalidate existing tokens + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3490e8042148579c5673f2cbd14092bea454c6fe..bc444af5cfa81c2063e413157d7429759f17aca7 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -1,30 +1,66 @@ package service import ( + "bytes" "context" + "crypto/sha256" "crypto/subtle" + "encoding/base64" + "encoding/hex" "fmt" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "image" + "image/color" + stddraw "image/draw" + _ "image/gif" + "image/jpeg" + _ "image/png" "log/slog" + "net/url" + "sort" + "strconv" "strings" + "sync" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + xdraw "golang.org/x/image/draw" + "golang.org/x/sync/singleflight" ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") - ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") + ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") + ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") + ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid") + ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid") + ErrIdentityUnbindLastMethod = infraerrors.Conflict( + "IDENTITY_UNBIND_LAST_METHOD", + "bind another sign-in method before unbinding this provider", + ) ) const ( - maxNotifyEmails = 3 // Maximum number of notification emails per user + maxNotifyEmails = 3 // Maximum number of notification emails per user + maxInlineAvatarBytes = 100 * 1024 + targetAvatarBytes = 20 * 1024 // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 notifyCodeUserRateWindow = 10 * time.Minute + + defaultUserIdentityRedirect = "/settings/profile" + userLastActiveMinTouch = 10 * time.Minute + userLastActiveFailBackoff = 30 * time.Second +) + +var ( + avatarScaleSteps = []float64{1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36} + avatarQualitySteps = []int{88, 80, 72, 64, 56, 48, 40, 32} ) // UserListFilters contains all filter options for listing users @@ -47,9 +83,15 @@ type UserRepository interface { GetFirstAdmin(ctx context.Context) (*User, error) Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error + GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) + UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + DeleteUserAvatar(ctx context.Context, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) + GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) + GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) + UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error @@ -60,6 +102,8 @@ type UserRepository interface { AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error // RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限 RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error + ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) + UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error @@ -67,15 +111,82 @@ type UserRepository interface { DisableTotp(ctx context.Context, userID int64) error } +type UserAuthIdentityRecord struct { + ProviderType string + ProviderKey string + ProviderSubject string + VerifiedAt *time.Time + Issuer *string + Metadata map[string]any + CreatedAt time.Time + UpdatedAt time.Time +} + +type UserIdentitySummary struct { + Provider string `json:"provider"` + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name,omitempty"` + SubjectHint string `json:"subject_hint,omitempty"` + ProviderKey string `json:"provider_key,omitempty"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + BindStartPath string `json:"bind_start_path,omitempty"` + CanBind bool `json:"can_bind"` + CanUnbind bool `json:"can_unbind"` + Note string `json:"note,omitempty"` +} + +type UserIdentitySummarySet struct { + Email UserIdentitySummary `json:"email"` + LinuxDo UserIdentitySummary `json:"linuxdo"` + OIDC UserIdentitySummary `json:"oidc"` + WeChat UserIdentitySummary `json:"wechat"` +} + +type StartUserIdentityBindingRequest struct { + Provider string + RedirectTo string +} + +type StartUserIdentityBindingResult struct { + Provider string `json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` +} + // UpdateProfileRequest 更新用户资料请求 type UpdateProfileRequest struct { Email *string `json:"email"` Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` Concurrency *int `json:"concurrency"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type UserAvatar struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type UpsertUserAvatarInput struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type userProfileIdentityTxRunner interface { + WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -88,6 +199,8 @@ type UserService struct { settingRepo SettingRepository authCacheInvalidator APIKeyAuthCacheInvalidator billingCache BillingCache + lastActiveTouchL1 sync.Map + lastActiveTouchSF singleflight.Group } // NewUserService 创建用户服务实例 @@ -115,14 +228,123 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) { + if user == nil { + var err error + user, err = s.userRepo.GetByID(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err) + } + } + + records, err := s.listUserAuthIdentities(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, err + } + + return UserIdentitySummarySet{ + Email: s.buildEmailIdentitySummary(user, records), + LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records), + OIDC: s.buildProviderIdentitySummary("oidc", user, records), + WeChat: s.buildProviderIdentitySummary("wechat", user, records), + }, nil +} + +func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) { + provider := normalizeUserIdentityProvider(req.Provider) + if provider == "" { + return nil, ErrIdentityProviderInvalid + } + + authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo) + if err != nil { + return nil, err + } + + return &StartUserIdentityBindingResult{ + Provider: provider, + AuthorizeURL: authorizeURL, + Method: "GET", + UseBrowserRedirect: true, + }, nil +} + +func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) { + provider = normalizeUserIdentityProvider(provider) + if provider == "" || provider == "email" { + return nil, ErrIdentityProviderInvalid + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + records, err := s.listUserAuthIdentities(ctx, userID) + if err != nil { + return nil, err + } + if len(filterUserAuthIdentities(records, provider)) == 0 { + return user, nil + } + if !s.canUnbindProvider(provider, user, records) { + return nil, ErrIdentityUnbindLastMethod + } + + if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil { + return nil, err + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + + updatedUser, err := s.GetProfile(ctx, userID) + if err != nil { + return nil, err + } + return updatedUser, nil +} + // UpdateProfile 更新用户资料 func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) { + if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok { + var ( + updated *User + oldConcurrency int + ) + if err := txRunner.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + var err error + updated, oldConcurrency, err = s.updateProfile(txCtx, userID, req) + return err + }); err != nil { + return nil, err + } + if s.authCacheInvalidator != nil && updated != nil && updated.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return updated, nil + } + + updated, oldConcurrency, err := s.updateProfile(ctx, userID, req) + if err != nil { + return nil, err + } + if s.authCacheInvalidator != nil && updated.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return updated, nil +} + +func (s *UserService) updateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, int, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - return nil, fmt.Errorf("get user: %w", err) + return nil, 0, fmt.Errorf("get user: %w", err) } oldConcurrency := user.Concurrency @@ -131,10 +353,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat // 检查新邮箱是否已被使用 exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email) if err != nil { - return nil, fmt.Errorf("check email exists: %w", err) + return nil, oldConcurrency, fmt.Errorf("check email exists: %w", err) } if exists && *req.Email != user.Email { - return nil, ErrEmailExists + return nil, oldConcurrency, ErrEmailExists } user.Email = *req.Email } @@ -143,6 +365,14 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.Username = *req.Username } + if req.AvatarURL != nil { + avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL) + if err != nil { + return nil, oldConcurrency, err + } + applyUserAvatar(user, avatar) + } + if req.Concurrency != nil { user.Concurrency = *req.Concurrency } @@ -159,13 +389,423 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat } if err := s.userRepo.Update(ctx, user); err != nil { - return nil, fmt.Errorf("update user: %w", err) + return nil, oldConcurrency, fmt.Errorf("update user: %w", err) } - if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { - s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + + return user, oldConcurrency, nil +} + +func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) { + avatarValue := strings.TrimSpace(raw) + if avatarValue == "" { + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + return nil, nil } - return user, nil + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + return avatar, nil +} + +func applyUserAvatar(user *User, avatar *UserAvatar) { + if user == nil { + return + } + if avatar == nil { + user.AvatarURL = "" + user.AvatarSource = "" + user.AvatarMIME = "" + user.AvatarByteSize = 0 + user.AvatarSHA256 = "" + return + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 +} + +func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.HasPrefix(raw, "data:") { + return normalizeInlineUserAvatarInput(raw) + } + + parsed, err := url.Parse(raw) + if err != nil || parsed == nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.TrimSpace(parsed.Host) == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + return UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: raw, + }, nil +} + +func ValidateUserAvatar(raw string) error { + _, err := normalizeUserAvatarInput(raw) + return err +} + +func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + body := strings.TrimPrefix(raw, "data:") + meta, encoded, ok := strings.Cut(body, ",") + if !ok { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + meta = strings.TrimSpace(meta) + encoded = strings.TrimSpace(encoded) + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")]) + if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return UpsertUserAvatarInput{}, ErrAvatarNotImage + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if len(decoded) > maxInlineAvatarBytes { + return UpsertUserAvatarInput{}, ErrAvatarTooLarge + } + + if len(decoded) > targetAvatarBytes { + decoded, contentType, err = compressInlineAvatar(decoded) + if err != nil { + return UpsertUserAvatarInput{}, err + } + raw = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(decoded) + } + + sum := sha256.Sum256(decoded) + return UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: raw, + ContentType: contentType, + ByteSize: len(decoded), + SHA256: hex.EncodeToString(sum[:]), + }, nil +} + +func compressInlineAvatar(decoded []byte) ([]byte, string, error) { + src, _, err := image.Decode(bytes.NewReader(decoded)) + if err != nil { + return nil, "", ErrAvatarInvalid + } + + srcBounds := src.Bounds() + if srcBounds.Empty() { + return nil, "", ErrAvatarInvalid + } + + for _, scale := range avatarScaleSteps { + width := max(1, int(float64(srcBounds.Dx())*scale)) + height := max(1, int(float64(srcBounds.Dy())*scale)) + dst := image.NewRGBA(image.Rect(0, 0, width, height)) + stddraw.Draw(dst, dst.Bounds(), &image.Uniform{C: color.White}, image.Point{}, stddraw.Src) + xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, srcBounds, stddraw.Over, nil) + + for _, quality := range avatarQualitySteps { + var buf bytes.Buffer + if err := jpeg.Encode(&buf, dst, &jpeg.Options{Quality: quality}); err != nil { + return nil, "", ErrAvatarInvalid + } + if buf.Len() <= targetAvatarBytes { + return buf.Bytes(), "image/jpeg", nil + } + } + } + + return nil, "", ErrAvatarTooLarge +} + +func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: "email", + CanBind: false, + CanUnbind: false, + Note: "Primary account email is managed from the profile form.", + } + if user == nil { + return summary + } + + filtered := filterUserAuthIdentities(records, "email") + if len(filtered) > 0 { + primary := selectPrimaryUserAuthIdentity(filtered) + email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email")) + if email == "" { + email = strings.TrimSpace(primary.ProviderSubject) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(user.Email) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(primary.ProviderKey) + } + + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + return summary + } + + // Compatibility fallback for legacy normal-email users that predate auth_identities backfill. + email := strings.TrimSpace(user.Email) + if email == "" || isReservedEmail(email) { + return summary + } + summary.Bound = true + summary.BoundCount = 1 + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = "email" + return summary +} + +func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: provider, + CanUnbind: false, + } + filtered := filterUserAuthIdentities(records, provider) + if len(filtered) == 0 { + summary.CanBind = true + bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "") + if err == nil { + summary.BindStartPath = bindStartPath + } + return summary + } + + primary := selectPrimaryUserAuthIdentity(filtered) + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = userAuthIdentityDisplayName(primary) + summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + summary.CanUnbind = s.canUnbindProvider(provider, user, records) + if summary.CanUnbind { + summary.Note = "You can unbind this sign-in method." + } else { + summary.Note = "Bind another sign-in method before unbinding." + } + return summary +} + +func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool { + if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 { + return false + } + + if s.buildEmailIdentitySummary(user, records).Bound { + return true + } + + for _, candidate := range []string{"linuxdo", "oidc", "wechat"} { + if candidate == provider { + continue + } + if len(filterUserAuthIdentities(records, candidate)) > 0 { + return true + } + } + + return false +} + +func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + if userID <= 0 || s == nil || s.userRepo == nil { + return nil, nil + } + return s.userRepo.ListUserAuthIdentities(ctx, userID) +} + +func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) { + provider = normalizeUserIdentityProvider(provider) + if provider == "" || provider == "email" { + return "", ErrIdentityProviderInvalid + } + + redirectTo, err := normalizeUserIdentityRedirect(redirectTo) + if err != nil { + return "", err + } + + path := "" + switch provider { + case "linuxdo": + path = "/api/v1/auth/oauth/linuxdo/start" + case "oidc": + path = "/api/v1/auth/oauth/oidc/start" + case "wechat": + path = "/api/v1/auth/oauth/wechat/start" + default: + return "", ErrIdentityProviderInvalid + } + + query := url.Values{} + query.Set("redirect", redirectTo) + query.Set("intent", "bind_current_user") + return path + "?" + query.Encode(), nil +} + +func normalizeUserIdentityProvider(provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + case "email": + return "email" + default: + return "" + } +} + +func normalizeUserIdentityRedirect(raw string) (string, error) { + redirect := strings.TrimSpace(raw) + if redirect == "" { + return defaultUserIdentityRedirect, nil + } + if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { + return "", ErrIdentityRedirectInvalid + } + return redirect, nil +} + +func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord { + if len(records) == 0 { + return nil + } + filtered := make([]UserAuthIdentityRecord, 0, len(records)) + for _, record := range records { + if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) { + filtered = append(filtered, record) + } + } + return filtered +} + +func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord { + if len(records) == 0 { + return UserAuthIdentityRecord{} + } + sort.SliceStable(records, func(i, j int) bool { + left := userAuthIdentitySortTime(records[i]) + right := userAuthIdentitySortTime(records[j]) + if !left.Equal(right) { + return left.After(right) + } + return records[i].ProviderKey < records[j].ProviderKey + }) + return records[0] +} + +func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time { + if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() { + return record.VerifiedAt.UTC() + } + if !record.UpdatedAt.IsZero() { + return record.UpdatedAt.UTC() + } + if !record.CreatedAt.IsZero() { + return record.CreatedAt.UTC() + } + return time.Time{} +} + +func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string { + if displayName := firstStringIdentityValue(record.Metadata, + "display_name", + "suggested_display_name", + "username", + "name", + "nickname", + "email", + ); displayName != "" { + return displayName + } + if subject := strings.TrimSpace(record.ProviderSubject); subject != "" { + return subject + } + return strings.TrimSpace(record.ProviderType) +} + +func firstStringIdentityValue(values map[string]any, keys ...string) string { + for _, key := range keys { + raw, ok := values[key] + if !ok { + continue + } + switch value := raw.(type) { + case string: + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + case fmt.Stringer: + if trimmed := strings.TrimSpace(value.String()); trimmed != "" { + return trimmed + } + } + } + return "" +} + +func maskEmailIdentity(email string) string { + local, domain, ok := strings.Cut(strings.TrimSpace(email), "@") + if !ok || local == "" || domain == "" { + return maskOpaqueIdentity(email) + } + runes := []rune(local) + if len(runes) == 1 { + return string(runes[0]) + "***@" + domain + } + return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain +} + +func maskOpaqueIdentity(value string) string { + value = strings.TrimSpace(value) + runes := []rune(value) + switch { + case len(runes) == 0: + return "" + case len(runes) <= 4: + return string(runes[0]) + "***" + case len(runes) <= 8: + return string(runes[:2]) + "***" + string(runes[len(runes)-1:]) + default: + return string(runes[:3]) + "***" + string(runes[len(runes)-3:]) + } } // ChangePassword 修改密码 @@ -202,9 +842,85 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。 +// 该操作为尽力而为,不应中断正常请求。 +func (s *UserService) TouchLastActive(ctx context.Context, userID int64) { + if s == nil || s.userRepo == nil || userID <= 0 { + return + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err) + return + } + s.TouchLastActiveForUser(ctx, user) +} + +// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。 +func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) { + if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 { + return + } + + now := time.Now() + if userLastActiveFresh(user.LastActiveAt, now) { + return + } + if v, ok := s.lastActiveTouchL1.Load(user.ID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return + } + } + + _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastActiveTouchL1.Load(user.ID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + if userLastActiveFresh(user.LastActiveAt, latest) { + return nil, nil + } + if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil { + s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff)) + return nil, fmt.Errorf("touch user last active: %w", err) + } + s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch)) + return nil, nil + }) + if err != nil { + slog.Warn("touch user last active failed", "user_id", user.ID, "error", err) + } +} + +func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool { + if lastActiveAt == nil { + return false + } + return now.Before(lastActiveAt.Add(userLastActiveMinTouch)) +} + +func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error { + if s == nil || s.userRepo == nil || user == nil || user.ID == 0 { + return nil + } + + avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + applyUserAvatar(user, avatar) + return nil +} + // List 获取用户列表(管理员功能) func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..702b3b1a21503ecdd32f96ad87919788cb14f07b --- /dev/null +++ b/backend/internal/service/user_service_email_identity_sync_test.go @@ -0,0 +1,34 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + user: &User{ + ID: 19, + Email: "profile-before@example.com", + Username: "tester", + Concurrency: 2, + }, + replaceErr: context.DeadlineExceeded, + } + svc := NewUserService(repo, nil, nil, nil) + + newEmail := "profile-after@example.com" + updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{ + Email: &newEmail, + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, newEmail, updated.Email) + require.Equal(t, 1, repo.updateCalls) + require.Empty(t, repo.replaceCalls) + require.Empty(t, repo.ensureCalls) +} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index a998d5f443d06d3a2238ad6923b0d4afd74a5579..88bb1637839cd6bde13ec9973b905cbd3f357bef 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -3,8 +3,14 @@ package service import ( + "bytes" "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "errors" + "image" + "image/png" "sync" "sync/atomic" "testing" @@ -17,16 +23,121 @@ import ( // --- mock: UserRepository --- type mockUserRepo struct { - updateBalanceErr error - updateBalanceFn func(ctx context.Context, id int64, amount float64) error + updateBalanceErr error + updateBalanceFn func(ctx context.Context, id int64, amount float64) error + getByIDUser *User + getByIDErr error + identities []UserAuthIdentityRecord + unbindIdentityErr error + unboundProviders []string + updateLastActiveErr error + updateLastActiveUserIDs []int64 + updateLastActiveAt []time.Time + updateFn func(ctx context.Context, user *User) error + updateCalls int + upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarFn func(ctx context.Context, userID int64) error + deleteAvatarIDs []int64 + getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error) + txCalls int } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +type mockUserRepoTxKey struct{} + +type mockUserRepoTxState struct { + getByIDUser *User + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarIDs []int64 +} + +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) { + if m.getByIDErr != nil { + return nil, m.getByIDErr + } + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil && txState.getByIDUser != nil { + cloned := *txState.getByIDUser + return &cloned, nil + } + if m.getByIDUser != nil { + cloned := *m.getByIDUser + return &cloned, nil + } + return &User{}, nil +} func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Update(ctx context.Context, user *User) error { + m.updateCalls++ + if m.updateFn != nil { + return m.updateFn(ctx, user) + } + return nil +} +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + if m.getAvatarFn != nil { + return m.getAvatarFn(ctx, userID) + } + return nil, nil +} +func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil { + txState.upsertAvatarArgs = append(txState.upsertAvatarArgs, input) + if txState.getByIDUser != nil { + txState.getByIDUser.AvatarURL = input.URL + txState.getByIDUser.AvatarSource = input.StorageProvider + txState.getByIDUser.AvatarMIME = input.ContentType + txState.getByIDUser.AvatarByteSize = input.ByteSize + txState.getByIDUser.AvatarSHA256 = input.SHA256 + } + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil + } + m.upsertAvatarArgs = append(m.upsertAvatarArgs, input) + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil { + txState.deleteAvatarIDs = append(txState.deleteAvatarIDs, userID) + if txState.getByIDUser != nil { + txState.getByIDUser.AvatarURL = "" + txState.getByIDUser.AvatarSource = "" + txState.getByIDUser.AvatarMIME = "" + txState.getByIDUser.AvatarByteSize = 0 + txState.getByIDUser.AvatarSHA256 = "" + } + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil + } + m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID) + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil +} func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -39,6 +150,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float } return m.updateBalanceErr } +func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error { + m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID) + m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt) + return m.updateLastActiveErr +} func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil } func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil } func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } @@ -46,12 +162,58 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + out := make([]UserAuthIdentityRecord, len(m.identities)) + copy(out, m.identities) + return out, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } +func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error { + if m.unbindIdentityErr != nil { + return m.unbindIdentityErr + } + m.unboundProviders = append(m.unboundProviders, provider) + filtered := m.identities[:0] + for _, identity := range m.identities { + if identity.ProviderType == provider { + continue + } + filtered = append(filtered, identity) + } + m.identities = append([]UserAuthIdentityRecord(nil), filtered...) + return nil +} + +func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + m.txCalls++ + txState := &mockUserRepoTxState{ + upsertAvatarArgs: append([]UpsertUserAvatarInput(nil), m.upsertAvatarArgs...), + deleteAvatarIDs: append([]int64(nil), m.deleteAvatarIDs...), + } + if m.getByIDUser != nil { + userCopy := *m.getByIDUser + txState.getByIDUser = &userCopy + } + err := fn(context.WithValue(ctx, mockUserRepoTxKey{}, txState)) + if err != nil { + return err + } + m.getByIDUser = txState.getByIDUser + m.upsertAvatarArgs = txState.upsertAvatarArgs + m.deleteAvatarIDs = txState.deleteAvatarIDs + return nil +} // --- mock: APIKeyAuthCacheInvalidator --- @@ -132,6 +294,94 @@ func TestUpdateBalance_Success(t *testing.T) { require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存") } +func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-123456", + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser) + + require.NoError(t, err) + require.True(t, summaries.LinuxDo.Bound) + require.True(t, summaries.LinuxDo.CanUnbind) + require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName) + require.NotEmpty(t, summaries.LinuxDo.SubjectHint) +} + +func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "only-user@linuxdo-connect.invalid", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo") + + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + +func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-12", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo") + + require.NoError(t, err) + require.Equal(t, []string{"linuxdo"}, repo.unboundProviders) + require.Equal(t, int64(12), user.ID) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user) + require.NoError(t, err) + require.False(t, summaries.LinuxDo.Bound) + require.True(t, summaries.LinuxDo.CanBind) +} + func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { repo := &mockUserRepo{} svc := NewUserService(repo, nil, nil, nil) // billingCache = nil @@ -154,6 +404,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { }, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance") } +func TestTouchLastActive_UpdatesWhenStale(t *testing.T) { + stale := time.Now().Add(-11 * time.Minute) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 42, + LastActiveAt: &stale, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + svc.TouchLastActive(context.Background(), 42) + + require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs) + require.Len(t, repo.updateLastActiveAt, 1) + require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second) +} + +func TestTouchLastActive_SkipsWhenRecent(t *testing.T) { + recent := time.Now().Add(-time.Minute) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 42, + LastActiveAt: &recent, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + svc.TouchLastActive(context.Background(), 42) + + require.Empty(t, repo.updateLastActiveUserIDs) + require.Empty(t, repo.updateLastActiveAt) +} + func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} cache := &mockBillingCache{} @@ -200,3 +483,199 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, cache, svc.billingCache) } + +func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) { + raw := []byte("small-avatar") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + expectedSum := sha256.Sum256(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "avatar@example.com", + Username: "avatar-user", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType) + require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256) + require.Equal(t, dataURL, updated.AvatarURL) + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/png", updated.AvatarMIME) + require.Equal(t, len(raw), updated.AvatarByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256) +} + +func TestUpdateProfile_CompressesInlineAvatarToTwentyKilobytes(t *testing.T) { + var encoded bytes.Buffer + for _, size := range []int{192, 224, 256, 288} { + encoded.Reset() + var img image.RGBA + img.Rect = image.Rect(0, 0, size, size) + img.Stride = size * 4 + img.Pix = make([]byte, size*size*4) + for y := 0; y < size; y++ { + for x := 0; x < size; x++ { + offset := y*img.Stride + x*4 + img.Pix[offset] = uint8((x*x + y*17) % 255) + img.Pix[offset+1] = uint8((y*y + x*29) % 255) + img.Pix[offset+2] = uint8(((x * y) + x*13 + y*7) % 255) + img.Pix[offset+3] = 0xff + } + } + require.NoError(t, png.Encode(&encoded, &img)) + if encoded.Len() > 20*1024 && encoded.Len() <= maxInlineAvatarBytes { + break + } + } + require.Greater(t, encoded.Len(), 20*1024) + require.LessOrEqual(t, encoded.Len(), maxInlineAvatarBytes) + + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(encoded.Bytes()) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 17, + Email: "avatar-compress@example.com", + Username: "avatar-compress", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 17, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.LessOrEqual(t, repo.upsertAvatarArgs[0].ByteSize, 20*1024) + require.Equal(t, "image/jpeg", repo.upsertAvatarArgs[0].ContentType) + require.Contains(t, repo.upsertAvatarArgs[0].URL, "data:image/jpeg;base64,") + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/jpeg", updated.AvatarMIME) + require.LessOrEqual(t, updated.AvatarByteSize, 20*1024) + require.Contains(t, updated.AvatarURL, "data:image/jpeg;base64,") + require.NotEmpty(t, updated.AvatarSHA256) +} + +func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) { + raw := make([]byte, maxInlineAvatarBytes+1) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 8, + Email: "large-avatar@example.com", + Username: "too-large", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.ErrorIs(t, err, ErrAvatarTooLarge) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Zero(t, repo.updateCalls) +} + +func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) { + remoteURL := "https://cdn.example.com/avatar.png" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "remote-avatar@example.com", + Username: "remote-avatar", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL) + require.Equal(t, remoteURL, updated.AvatarURL) + require.Equal(t, "remote_url", updated.AvatarSource) + require.Zero(t, updated.AvatarByteSize) +} + +func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) { + empty := "" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "delete-avatar@example.com", + Username: "delete-avatar", + AvatarURL: "https://cdn.example.com/old.png", + AvatarSource: "remote_url", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{ + AvatarURL: &empty, + }) + require.NoError(t, err) + require.Equal(t, []int64{10}, repo.deleteAvatarIDs) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, updated.AvatarURL) + require.Empty(t, updated.AvatarSource) +} + +func TestUpdateProfile_RollsBackAvatarMutationWhenUserUpdateFails(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 11, + Email: "rollback@example.com", + AvatarURL: "https://cdn.example.com/original.png", + AvatarSource: "remote_url", + }, + updateFn: func(context.Context, *User) error { + return errors.New("write user failed") + }, + } + svc := NewUserService(repo, nil, nil, nil) + + remoteURL := "https://cdn.example.com/new.png" + _, err := svc.UpdateProfile(context.Background(), 11, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + + require.EqualError(t, err, "update user: write user failed") + require.Equal(t, 1, repo.txCalls) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Equal(t, "https://cdn.example.com/original.png", repo.getByIDUser.AvatarURL) + require.Equal(t, "remote_url", repo.getByIDUser.AvatarSource) +} + +func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "profile-avatar@example.com", + Username: "profile-avatar", + }, + getAvatarFn: func(context.Context, int64) (*UserAvatar, error) { + return &UserAvatar{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/profile.png", + }, nil + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.GetProfile(context.Background(), 12) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL) + require.Equal(t, "remote_url", user.AvatarSource) +} diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql new file mode 100644 index 0000000000000000000000000000000000000000..117e3ca38c5c11b00491c298344d9ada4e14650c --- /dev/null +++ b/backend/migrations/108_auth_identity_foundation_core.sql @@ -0,0 +1,141 @@ +ALTER TABLE users +ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email', +ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL; + +UPDATE users +SET signup_source = 'email' +WHERE signup_source IS NULL OR signup_source = ''; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'users_signup_source_check' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc')); + END IF; +END $$; + +CREATE TABLE IF NOT EXISTS auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + verified_at TIMESTAMPTZ NULL, + issuer TEXT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key + ON auth_identities (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx + ON auth_identities (user_id); + +CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx + ON auth_identities (user_id, provider_type); + +CREATE TABLE IF NOT EXISTS auth_identity_channels ( + id BIGSERIAL PRIMARY KEY, + identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + channel VARCHAR(20) NOT NULL, + channel_app_id TEXT NOT NULL, + channel_subject TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key + ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject); + +CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx + ON auth_identity_channels (identity_id); + +CREATE TABLE IF NOT EXISTS pending_auth_sessions ( + id BIGSERIAL PRIMARY KEY, + session_token VARCHAR(255) NOT NULL, + intent VARCHAR(40) NOT NULL, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + redirect_to TEXT NOT NULL DEFAULT '', + resolved_email TEXT NOT NULL DEFAULT '', + registration_password_hash TEXT NOT NULL DEFAULT '', + upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb, + local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb, + browser_session_key TEXT NOT NULL DEFAULT '', + completion_code_hash TEXT NOT NULL DEFAULT '', + completion_code_expires_at TIMESTAMPTZ NULL, + email_verified_at TIMESTAMPTZ NULL, + password_verified_at TIMESTAMPTZ NULL, + totp_verified_at TIMESTAMPTZ NULL, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT pending_auth_sessions_intent_check + CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')), + CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key + ON pending_auth_sessions (session_token); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx + ON pending_auth_sessions (target_user_id); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx + ON pending_auth_sessions (expires_at); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx + ON pending_auth_sessions (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx + ON pending_auth_sessions (completion_code_hash); + +CREATE TABLE IF NOT EXISTS identity_adoption_decisions ( + id BIGSERIAL PRIMARY KEY, + pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE, + identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL, + adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE, + adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE, + decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key + ON identity_adoption_decisions (pending_auth_session_id); + +CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx + ON identity_adoption_decisions (identity_id); + +CREATE TABLE IF NOT EXISTS auth_identity_migration_reports ( + id BIGSERIAL PRIMARY KEY, + report_type VARCHAR(40) NOT NULL, + report_key TEXT NOT NULL, + details JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx + ON auth_identity_migration_reports (report_type); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key + ON auth_identity_migration_reports (report_type, report_key); diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql new file mode 100644 index 0000000000000000000000000000000000000000..5147ae45a5eed3000d4171f45ea90c427c3b4e54 --- /dev/null +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -0,0 +1,128 @@ +ALTER TABLE auth_identity_migration_reports +ALTER COLUMN report_type TYPE VARCHAR(80); + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'email', + 'email', + LOWER(BTRIM(u.email)), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'users.email', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(u.email, '')) <> '' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'linuxdo', + 'linuxdo', + SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'wechat', + 'wechat', + SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +UPDATE users +SET signup_source = 'linuxdo' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'; + +UPDATE users +SET signup_source = 'wechat' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$'; + +UPDATE users +SET signup_source = 'oidc' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$'; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'oidc_synthetic_email_requires_manual_recovery', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities ai + WHERE ai.user_id = u.id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + ) +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql new file mode 100644 index 0000000000000000000000000000000000000000..fbaed62e4b58fb29f890eca80b6b39d474458f6f --- /dev/null +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -0,0 +1,60 @@ +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind', + granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT user_provider_default_grants_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')), + CONSTRAINT user_provider_default_grants_reason_check + CHECK (grant_reason IN ('signup', 'first_bind')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key + ON user_provider_default_grants (user_id, provider_type, grant_reason); + +CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx + ON user_provider_default_grants (user_id); + +CREATE TABLE IF NOT EXISTS user_avatars ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_provider VARCHAR(20) NOT NULL DEFAULT 'database', + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL DEFAULT '', + content_type VARCHAR(100) NOT NULL DEFAULT '', + byte_size INT NOT NULL DEFAULT 0, + sha256 VARCHAR(64) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key + ON user_avatars (user_id); + +INSERT INTO settings (key, value) +VALUES + ('auth_source_default_email_balance', '0'), + ('auth_source_default_email_concurrency', '5'), + ('auth_source_default_email_subscriptions', '[]'), + ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_first_bind', 'false'), + ('auth_source_default_linuxdo_balance', '0'), + ('auth_source_default_linuxdo_concurrency', '5'), + ('auth_source_default_linuxdo_subscriptions', '[]'), + ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), + ('auth_source_default_oidc_balance', '0'), + ('auth_source_default_oidc_concurrency', '5'), + ('auth_source_default_oidc_subscriptions', '[]'), + ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_first_bind', 'false'), + ('auth_source_default_wechat_balance', '0'), + ('auth_source_default_wechat_concurrency', '5'), + ('auth_source_default_wechat_subscriptions', '[]'), + ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_first_bind', 'false'), + ('force_email_on_third_party_signup', 'false') +ON CONFLICT (key) DO NOTHING; + diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql new file mode 100644 index 0000000000000000000000000000000000000000..f222a8d40a9f18b376409ccb9587715eda637985 --- /dev/null +++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql @@ -0,0 +1,8 @@ +INSERT INTO settings (key, value) +VALUES + ('payment_visible_method_alipay_source', ''), + ('payment_visible_method_wxpay_source', ''), + ('payment_visible_method_alipay_enabled', 'false'), + ('payment_visible_method_wxpay_enabled', 'false'), + ('openai_advanced_scheduler_enabled', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql new file mode 100644 index 0000000000000000000000000000000000000000..7ec19ae32a557089072b37edafb8a125417683f5 --- /dev/null +++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql @@ -0,0 +1,10 @@ +ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30); + +UPDATE payment_orders +SET provider_key = ( + SELECT provider_key + FROM payment_provider_instances + WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id +) +WHERE provider_key IS NULL + AND provider_instance_id IS NOT NULL; diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql new file mode 100644 index 0000000000000000000000000000000000000000..15610af0d2a7f9d660ccb534c068d3a87c84619a --- /dev/null +++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql @@ -0,0 +1,89 @@ +UPDATE auth_identities AS ai +SET + provider_key = 'wechat-main', + metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject + ); + +UPDATE auth_identity_channels AS channel +SET + provider_key = 'wechat-main', + metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identity_channels AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject + ); + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_provider_key_conflict', + CAST(ai.id AS TEXT), + jsonb_build_object( + 'legacy_identity_id', ai.id, + 'legacy_user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'canonical_identity_id', canon.id, + 'canonical_user_id', canon.user_id, + 'same_user', canon.user_id = ai.user_id, + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identities AS ai +JOIN auth_identities AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_channel_provider_key_conflict', + CAST(channel.id AS TEXT), + jsonb_build_object( + 'legacy_channel_id', channel.id, + 'legacy_identity_id', channel.identity_id, + 'canonical_channel_id', canon.id, + 'canonical_identity_id', canon.identity_id, + 'channel', channel.channel, + 'channel_app_id', channel.channel_app_id, + 'channel_subject', channel.channel_subject, + 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE), + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identity_channels AS channel +JOIN auth_identity_channels AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject +LEFT JOIN auth_identities AS legacy_identity + ON legacy_identity.id = channel.identity_id +LEFT JOIN auth_identities AS canonical_identity + ON canonical_identity.id = canon.identity_id +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql new file mode 100644 index 0000000000000000000000000000000000000000..f84bf822921fc5135c5ab5d659302b40ad417617 --- /dev/null +++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql @@ -0,0 +1,11 @@ +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT ''; + +CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at + ON auth_identity_migration_reports (resolved_at); diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql new file mode 100644 index 0000000000000000000000000000000000000000..7a20f8eb74926fdbca5e2d07d95d2072ba4a6859 --- /dev/null +++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql @@ -0,0 +1,215 @@ +CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT) +RETURNS JSONB +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN '{}'::jsonb; + END IF; + + BEGIN + parsed := input_text::jsonb; + EXCEPTION + WHEN OTHERS THEN + RETURN '{}'::jsonb; + END; + + IF jsonb_typeof(parsed) = 'object' THEN + RETURN parsed; + END IF; + + RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed); +END; +$$; + +DO $$ +BEGIN + IF to_regclass('public.user_external_identities') IS NULL THEN + RETURN; + END IF; + + EXECUTE $sql$ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + 'linuxdo', + 'linuxdo', + legacy.provider_user_id, + COALESCE(legacy.updated_at, legacy.created_at, NOW()), + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +) AS legacy +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + 'wechat', + 'wechat-main', + legacy.provider_union_id, + COALESCE(legacy.updated_at, legacy.created_at, NOW()), + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'openid', legacy.provider_user_id, + 'unionid', legacy.provider_union_id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', legacy.provider_union_id, + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +) AS legacy +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +SELECT + ai.id, + 'wechat', + 'wechat-main', + legacy.channel, + legacy.channel_app_id, + legacy.provider_user_id, + legacy.metadata_json || jsonb_build_object( + 'openid', legacy.provider_user_id, + 'unionid', legacy.provider_union_id, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM ( + SELECT + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, + BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, + meta.metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + CROSS JOIN LATERAL ( + SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + ) AS meta + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +) AS legacy +JOIN auth_identities AS ai + ON ai.user_id = legacy.user_id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat-main' + AND ai.provider_subject = legacy.provider_union_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND legacy.provider_user_id <> '' +ON CONFLICT DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'user_id', legacy.user_id, + 'openid', legacy.provider_user_id, + 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline', + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_union_id, '')) = '' +) AS legacy +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; +END $$; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'synthetic_auth_identity:' || ai.id::text, + COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'auth_identity_id', ai.id, + 'user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation', + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM auth_identities AS ai +WHERE ai.provider_type = 'wechat' + AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email' + AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = '' +ON CONFLICT (report_type, report_key) DO NOTHING; + +DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT); diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql new file mode 100644 index 0000000000000000000000000000000000000000..3983bb1a5ddda793ebd2826972b09a2521dd72eb --- /dev/null +++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql @@ -0,0 +1,369 @@ +CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT) +RETURNS JSONB +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN '{}'::jsonb; + END IF; + + BEGIN + parsed := input_text::jsonb; + EXCEPTION + WHEN OTHERS THEN + RETURN '{}'::jsonb; + END; + + IF jsonb_typeof(parsed) = 'object' THEN + RETURN parsed; + END IF; + + RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed); +END; +$$; + +CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT) +RETURNS BOOLEAN +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN TRUE; + END IF; + + parsed := input_text::jsonb; + RETURN TRUE; +EXCEPTION + WHEN OTHERS THEN + RETURN FALSE; +END; +$$; + +DO $$ +BEGIN + IF to_regclass('public.user_external_identities') IS NULL THEN + RETURN; + END IF; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_invalid_metadata_json', + 'legacy_external_identity:' || uei.id::text, + jsonb_build_object( + 'legacy_identity_id', uei.id, + 'user_id', uei.user_id, + 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))), + 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')), + 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')), + 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object', + 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000), + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM user_external_identities AS uei +JOIN users AS u ON u.id = uei.user_id +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(uei.metadata, '')) <> '' + AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata) +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'existing_identity_id', ai.id, + 'existing_user_id', ai.user_id, + 'provider_type', legacy.provider_type, + 'provider_key', legacy.provider_key, + 'provider_subject', legacy.provider_subject, + 'reason', 'legacy canonical identity subject already belongs to another user', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + BTRIM(COALESCE(uei.provider_username, '')) AS provider_username, + BTRIM(COALESCE(uei.display_name, '')) AS display_name, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +JOIN auth_identities AS ai + ON ai.provider_type = legacy.provider_type + AND ai.provider_key = legacy.provider_key + AND ai.provider_subject = legacy.provider_subject +WHERE ai.user_id <> legacy.user_id +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + legacy.provider_type, + legacy.provider_key, + legacy.provider_subject, + legacy.verified_at, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', NULLIF(legacy.provider_union_id, ''), + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + BTRIM(COALESCE(uei.provider_username, '')) AS provider_username, + BTRIM(COALESCE(uei.display_name, '')) AS display_name, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +LEFT JOIN auth_identities AS ai + ON ai.provider_type = legacy.provider_type + AND ai.provider_key = legacy.provider_key + AND ai.provider_subject = legacy.provider_subject +WHERE ai.id IS NULL +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_channel_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'existing_channel_id', channel.id, + 'existing_identity_id', existing_ai.id, + 'existing_user_id', existing_ai.user_id, + 'provider_type', 'wechat', + 'provider_key', 'wechat-main', + 'provider_subject', legacy.provider_union_id, + 'channel', legacy.channel, + 'channel_app_id', legacy.channel_app_id, + 'channel_subject', legacy.provider_user_id, + 'reason', 'legacy channel subject already belongs to another user', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +) AS legacy +JOIN auth_identities AS legacy_ai + ON legacy_ai.user_id = legacy.user_id + AND legacy_ai.provider_type = 'wechat' + AND legacy_ai.provider_key = 'wechat-main' + AND legacy_ai.provider_subject = legacy.provider_union_id +JOIN auth_identity_channels AS channel + ON channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = legacy.channel + AND channel.channel_app_id = legacy.channel_app_id + AND channel.channel_subject = legacy.provider_user_id +JOIN auth_identities AS existing_ai + ON existing_ai.id = channel.identity_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND existing_ai.user_id <> legacy.user_id +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +SELECT + legacy_ai.id, + 'wechat', + 'wechat-main', + legacy.channel, + legacy.channel_app_id, + legacy.provider_user_id, + legacy.metadata_json || jsonb_build_object( + 'openid', legacy.provider_user_id, + 'unionid', legacy.provider_union_id, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +) AS legacy +JOIN auth_identities AS legacy_ai + ON legacy_ai.user_id = legacy.user_id + AND legacy_ai.provider_type = 'wechat' + AND legacy_ai.provider_key = 'wechat-main' + AND legacy_ai.provider_subject = legacy.provider_union_id +LEFT JOIN auth_identity_channels AS channel + ON channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = legacy.channel + AND channel.channel_app_id = legacy.channel_app_id + AND channel.channel_subject = legacy.provider_user_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND channel.id IS NULL +ON CONFLICT DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'user_id', legacy.user_id, + 'openid', legacy.provider_user_id, + 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_union_id, '')) = '' +) AS legacy +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identities_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identities + ADD CONSTRAINT auth_identities_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_channels_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identity_channels + ADD CONSTRAINT auth_identity_channels_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_migration_reports_details_is_object_check' + ) THEN + ALTER TABLE auth_identity_migration_reports + ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check + CHECK (jsonb_typeof(details) = 'object'); + END IF; +END $$; + +DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT); +DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT); diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql new file mode 100644 index 0000000000000000000000000000000000000000..56a5fe2dc504bf9fa34d04588206f9a43b66af43 --- /dev/null +++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql @@ -0,0 +1,2 @@ +ALTER TABLE payment_orders +ADD COLUMN IF NOT EXISTS provider_snapshot JSONB; diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql new file mode 100644 index 0000000000000000000000000000000000000000..6eef59e29d6e7666b1788b71c32eb3817a197b7e --- /dev/null +++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql @@ -0,0 +1,32 @@ +INSERT INTO settings (key, value) +VALUES + ( + 'wechat_connect_open_enabled', + CASE + WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' + WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false' + ELSE 'true' + END + ), + ( + 'wechat_connect_mp_enabled', + CASE + WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' + WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true' + ELSE 'false' + END + ), + ('auth_source_default_email_grant_on_signup', 'false'), + ('auth_source_default_linuxdo_grant_on_signup', 'false'), + ('auth_source_default_oidc_grant_on_signup', 'false'), + ('auth_source_default_wechat_grant_on_signup', 'false') +ON CONFLICT (key) DO NOTHING; + +UPDATE settings +SET value = 'false' +WHERE key IN ( + 'auth_source_default_email_grant_on_signup', + 'auth_source_default_linuxdo_grant_on_signup', + 'auth_source_default_oidc_grant_on_signup', + 'auth_source_default_wechat_grant_on_signup' +); diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md deleted file mode 100644 index f674f86c934805bad2284554a04de31b151609e5..0000000000000000000000000000000000000000 --- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md +++ /dev/null @@ -1,243 +0,0 @@ -# ADMIN_PAYMENT_INTEGRATION_API - -> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English) - ---- - -## 中文 - -### 目标 -本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖: -- 支付成功后充值 -- 用户查询 -- 人工余额修正 -- 前端购买页参数透传 - -### 基础地址 -- 生产:`https://` -- Beta:`http://:8084` - -### 认证 -推荐使用: -- `x-api-key: admin-<64hex>` -- `Content-Type: application/json` -- 幂等接口额外传:`Idempotency-Key` - -说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。 - -### 1) 一步完成创建并兑换 -`POST /api/v1/admin/redeem-codes/create-and-redeem` - -用途:原子完成“创建兑换码 + 兑换到指定用户”。 - -请求头: -- `x-api-key` -- `Idempotency-Key` - -请求体示例: -```json -{ - "code": "s2p_cm1234567890", - "type": "balance", - "value": 100.0, - "user_id": 123, - "notes": "sub2apipay order: cm1234567890" -} -``` - -幂等语义: -- 同 `code` 且 `used_by` 一致:`200` -- 同 `code` 但 `used_by` 不一致:`409` -- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`) - -curl 示例: -```bash -curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: pay-cm1234567890-success" \ - -H "Content-Type: application/json" \ - -d '{ - "code":"s2p_cm1234567890", - "type":"balance", - "value":100.00, - "user_id":123, - "notes":"sub2apipay order: cm1234567890" - }' -``` - -### 2) 查询用户(可选前置校验) -`GET /api/v1/admin/users/:id` - -```bash -curl -s "${BASE}/api/v1/admin/users/123" \ - -H "x-api-key: ${KEY}" -``` - -### 3) 余额调整(已有接口) -`POST /api/v1/admin/users/:id/balance` - -用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。 - -请求体示例(扣减): -```json -{ - "balance": 100.0, - "operation": "subtract", - "notes": "manual correction" -} -``` - -```bash -curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: balance-subtract-cm1234567890" \ - -H "Content-Type: application/json" \ - -d '{ - "balance":100.00, - "operation":"subtract", - "notes":"manual correction" - }' -``` - -### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致) -当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加: -- `user_id` -- `token` -- `theme`(`light` / `dark`) -- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言) -- `ui_mode`(固定 `embedded`) - -示例: -```text -https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded -``` - -### 5) 失败处理建议 -- 支付成功与充值成功分状态落库 -- 回调验签成功后立即标记“支付成功” -- 支付成功但充值失败的订单允许后续重试 -- 重试保持相同 `code`,并使用新的 `Idempotency-Key` - -### 6) `doc_url` 配置建议 -- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md` -- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md` - ---- - -## English - -### Purpose -This document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including: -- Recharge after payment success -- User lookup -- Manual balance correction -- Purchase page query parameter forwarding - -### Base URL -- Production: `https://` -- Beta: `http://:8084` - -### Authentication -Recommended headers: -- `x-api-key: admin-<64hex>` -- `Content-Type: application/json` -- `Idempotency-Key` for idempotent endpoints - -Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration. - -### 1) Create and Redeem in one step -`POST /api/v1/admin/redeem-codes/create-and-redeem` - -Use case: atomically create a redeem code and redeem it to a target user. - -Headers: -- `x-api-key` -- `Idempotency-Key` - -Request body: -```json -{ - "code": "s2p_cm1234567890", - "type": "balance", - "value": 100.0, - "user_id": 123, - "notes": "sub2apipay order: cm1234567890" -} -``` - -Idempotency behavior: -- Same `code` and same `used_by`: `200` -- Same `code` but different `used_by`: `409` -- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`) - -curl example: -```bash -curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: pay-cm1234567890-success" \ - -H "Content-Type: application/json" \ - -d '{ - "code":"s2p_cm1234567890", - "type":"balance", - "value":100.00, - "user_id":123, - "notes":"sub2apipay order: cm1234567890" - }' -``` - -### 2) Query User (optional pre-check) -`GET /api/v1/admin/users/:id` - -```bash -curl -s "${BASE}/api/v1/admin/users/123" \ - -H "x-api-key: ${KEY}" -``` - -### 3) Balance Adjustment (existing API) -`POST /api/v1/admin/users/:id/balance` - -Use case: manual correction with `set` / `add` / `subtract`. - -Request body example (`subtract`): -```json -{ - "balance": 100.0, - "operation": "subtract", - "notes": "manual correction" -} -``` - -```bash -curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: balance-subtract-cm1234567890" \ - -H "Content-Type: application/json" \ - -d '{ - "balance":100.00, - "operation":"subtract", - "notes":"manual correction" - }' -``` - -### 4) Purchase / Custom Page URL query forwarding (iframe and new tab) -When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends: -- `user_id` -- `token` -- `theme` (`light` / `dark`) -- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page) -- `ui_mode` (fixed: `embedded`) - -Example: -```text -https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded -``` - -### 5) Failure handling recommendations -- Persist payment success and recharge success as separate states -- Mark payment as successful immediately after verified callback -- Allow retry for orders with payment success but recharge failure -- Keep the same `code` for retry, and use a new `Idempotency-Key` - -### 6) Recommended `doc_url` -- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md` -- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md` diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md deleted file mode 100644 index 755b313a63d230d6c3a60a89092d89fc3b813e2b..0000000000000000000000000000000000000000 --- a/docs/PAYMENT.md +++ /dev/null @@ -1,278 +0,0 @@ -# Payment System Configuration Guide - -Sub2API has a built-in payment system that enables user self-service top-up without deploying a separate payment service. - ---- - -## Table of Contents - -- [Supported Payment Methods](#supported-payment-methods) -- [Quick Start](#quick-start) -- [System Settings](#system-settings) -- [Provider Configuration](#provider-configuration) -- [Provider Instance Management](#provider-instance-management) -- [Webhook Configuration](#webhook-configuration) -- [Payment Flow](#payment-flow) -- [Migrating from Sub2ApiPay](#migrating-from-sub2apipay) - ---- - -## Supported Payment Methods - -| Provider | Payment Methods | Description | -|----------|----------------|-------------| -| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol | -| **Alipay (Direct)** | PC Page Pay, H5 Mobile Pay | Direct integration with Alipay Open Platform, auto-switches by device | -| **WeChat Pay (Direct)** | Native QR Code, H5 Pay | Direct integration with WeChat Pay APIv3, mobile-first H5 | -| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support | - -> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup. - -> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need: -> -> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it. -> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer. -> -> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. - ---- - -## Quick Start - -1. Go to Admin Dashboard → **Settings** → **Payment Settings** tab -2. Enable **Payment** -3. Configure basic parameters (amount range, timeout, etc.) -4. Add at least one provider instance in **Provider Management** -5. Users can now top up from the frontend - ---- - -## System Settings - -Configure the following in Admin Dashboard **Settings → Payment Settings**: - -### Basic Settings - -| Setting | Description | Default | -|---------|-------------|---------| -| **Enable Payment** | Enable or disable the payment system | Off | -| **Product Name Prefix** | Prefix shown on payment page | - | -| **Product Name Suffix** | Suffix (e.g., "Credits") | - | -| **Minimum Amount** | Minimum single top-up amount | 1 | -| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - | -| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - | -| **Order Timeout** | Order timeout in minutes (minimum 1) | 5 | -| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 | -| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount | - -### Load Balance Strategies - -| Strategy | Description | -|----------|-------------| -| **Round Robin** | Distribute orders to instances in rotation | -| **Least Amount** | Prefer instances with the lowest daily cumulative amount | - -### Cancel Rate Limiting - -Prevents users from repeatedly creating and canceling orders: - -| Setting | Description | -|---------|-------------| -| **Enable Limit** | Toggle | -| **Window Mode** | Sliding / Fixed window | -| **Time Window** | Window duration | -| **Window Unit** | Minutes / Hours | -| **Max Cancels** | Maximum cancellations allowed within the window | - -### Help Information - -| Setting | Description | -|---------|-------------| -| **Help Image** | Customer service QR code or help image (supports upload) | -| **Help Text** | Instructions displayed on the payment page | - ---- - -## Provider Configuration - -Each provider type requires different credentials. Select the type when adding a new provider instance in **Provider Management → Add Provider**. - -> **Callback URLs are auto-generated**: When adding a provider, the Notify URL and Return URL are automatically constructed from your site domain. You only need to confirm the domain is correct. - -### EasyPay - -Compatible with any payment service that implements the EasyPay protocol. - -| Parameter | Description | Required | -|-----------|-------------|----------| -| **Merchant ID (PID)** | EasyPay merchant ID | Yes | -| **Merchant Key (PKey)** | EasyPay merchant secret key | Yes | -| **API Base URL** | EasyPay API base address | Yes | -| **Alipay Channel ID** | Specify Alipay channel (optional) | No | -| **WeChat Channel ID** | Specify WeChat channel (optional) | No | - -### Alipay (Direct) - -Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile pay. - -| Parameter | Description | Required | -|-----------|-------------|----------| -| **AppID** | Alipay application AppID | Yes | -| **Private Key** | RSA2 application private key | Yes | -| **Alipay Public Key** | Alipay public key | Yes | - -### WeChat Pay (Direct) - -Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment. - -| Parameter | Description | Required | -|-----------|-------------|----------| -| **AppID** | WeChat Pay AppID | Yes | -| **Merchant ID (MchID)** | WeChat Pay merchant ID | Yes | -| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes | -| **APIv3 Key** | 32-byte APIv3 key | Yes | -| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes | -| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No | -| **Certificate Serial Number** | Merchant certificate serial number | No | - -### Stripe - -International payment platform supporting multiple payment methods and currencies. - -| Parameter | Description | Required | -|-----------|-------------|----------| -| **Secret Key** | Stripe secret key (`sk_live_...` or `sk_test_...`) | Yes | -| **Publishable Key** | Stripe publishable key (`pk_live_...` or `pk_test_...`) | Yes | -| **Webhook Secret** | Stripe Webhook signing secret (`whsec_...`) | Yes | - ---- - -## Provider Instance Management - -You can create **multiple instances** of the same provider type for load balancing and risk control: - -- **Multi-instance load balancing** — Distribute orders via round-robin or least-amount strategy -- **Independent limits** — Each instance can have its own min/max amount and daily limit -- **Independent toggle** — Enable/disable individual instances without affecting others -- **Refund control** — Enable or disable refunds per instance -- **Payment methods** — Each instance can support a subset of payment methods -- **Ordering** — Drag to reorder instances - -### Instance Limit Configuration - -Each instance supports these limits: - -| Limit | Description | -|-------|-------------| -| **Minimum Amount** | Minimum order amount accepted by this instance | -| **Maximum Amount** | Maximum order amount accepted by this instance | -| **Daily Limit** | Daily cumulative transaction limit for this instance | - -> During load balancing, instances that exceed their limits are automatically skipped. - ---- - -## Webhook Configuration - -Payment callbacks are essential for the payment system to work correctly. - -### Callback URL Format - -When adding a provider, the system auto-generates callback URLs from your site domain: - -| Provider | Callback Path | -|----------|-------------| -| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` | -| **Alipay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/alipay` | -| **WeChat Pay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/wxpay` | -| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` | - -> Replace `your-domain.com` with your actual domain. For EasyPay / Alipay / WeChat Pay, the callback URL is auto-filled when adding the provider — no manual configuration needed. - -### Stripe Webhook Setup - -1. Log in to [Stripe Dashboard](https://dashboard.stripe.com/) -2. Go to **Developers → Webhooks** -3. Add an endpoint with the callback URL -4. Subscribe to events: `payment_intent.succeeded`, `payment_intent.payment_failed` -5. Copy the generated Webhook Secret (`whsec_...`) to your provider configuration - -### Important Notes - -- Callback URLs must use **HTTPS** (required by Stripe, strongly recommended for others) -- Ensure your firewall allows callback requests from payment platforms -- The system automatically verifies callback signatures to prevent forgery -- Balance top-up is processed automatically upon successful payment — no manual intervention needed - ---- - -## Payment Flow - -``` -User selects amount and payment method - │ - ▼ - Create Order (PENDING) - ├─ Validate amount range, pending order count, daily limit - ├─ Load balance to select provider instance - └─ Call provider to get payment info - │ - ▼ - User completes payment - ├─ EasyPay → QR code / H5 redirect - ├─ Alipay → PC page pay / H5 mobile pay - ├─ WeChat Pay → Native QR / H5 pay - └─ Stripe → Payment Element (card/Alipay/WeChat/etc.) - │ - ▼ - Webhook callback verified → Order PAID - │ - ▼ - Auto top-up to user balance → Order COMPLETED -``` - -### Order Status Reference - -| Status | Description | -|--------|-------------| -| `PENDING` | Waiting for user to complete payment | -| `PAID` | Payment confirmed, awaiting balance credit | -| `COMPLETED` | Balance credited successfully | -| `EXPIRED` | Timed out without payment | -| `CANCELLED` | Cancelled by user | -| `FAILED` | Balance credit failed, admin can retry | -| `REFUND_REQUESTED` | Refund requested | -| `REFUNDING` | Refund in progress | -| `REFUNDED` | Refund completed | - -### Timeout and Fallback - -- Before marking an order as expired, the background job queries the upstream payment status first -- If the user has actually paid but the callback was delayed, the system will reconcile automatically -- The background job runs every 60 seconds to check for timed-out orders - ---- - -## Migrating from Sub2ApiPay - -If you previously used [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) as an external payment system, you can migrate to the built-in payment system: - -### Key Differences - -| Aspect | Sub2ApiPay | Built-in Payment | -|--------|-----------|-----------------| -| Deployment | Separate service (Next.js + PostgreSQL) | Built into Sub2API, no extra deployment | -| Payment Methods | EasyPay, Alipay, WeChat, Stripe | Same | -| Configuration | Environment variables + separate admin UI | Unified in Sub2API admin dashboard | -| Top-up Integration | Via Admin API callback | Internal processing, more reliable | -| Subscription Plans | Supported | Not yet (planned) | -| Order Management | Separate admin interface | Integrated in Sub2API admin dashboard | - -### Migration Steps - -1. Enable payment in Sub2API admin dashboard and configure providers (use the same payment credentials) -2. Update webhook callback URLs to Sub2API's callback endpoints -3. Verify that new orders are processed correctly via built-in payment -4. Decommission the Sub2ApiPay service - -> **Note**: Historical order data from Sub2ApiPay will not be automatically migrated. Keep Sub2ApiPay running for a while to access historical records. diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md deleted file mode 100644 index aca3c866126007102af10517c83d466433f93c42..0000000000000000000000000000000000000000 --- a/docs/PAYMENT_CN.md +++ /dev/null @@ -1,278 +0,0 @@ -# 支付系统配置指南 - -Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支付服务。 - ---- - -## 目录 - -- [支持的支付方式](#支持的支付方式) -- [快速开始](#快速开始) -- [系统设置](#系统设置) -- [服务商配置](#服务商配置) -- [服务商实例管理](#服务商实例管理) -- [Webhook 配置](#webhook-配置) -- [支付流程](#支付流程) -- [从 Sub2ApiPay 迁移](#从-sub2apipay-迁移) - ---- - -## 支持的支付方式 - -| 服务商 | 支付方式 | 说明 | -|--------|---------|------| -| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 | -| **支付宝官方** | 支付宝 PC 页面支付、H5 手机网站支付 | 直接对接支付宝开放平台,自动根据终端切换 | -| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 | -| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 | - -> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。 - -> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择: -> -> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。 -> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。 -> -> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 - ---- - -## 快速开始 - -1. 进入管理后台 → **设置** → **支付设置** 标签页 -2. 开启 **启用支付** -3. 配置基本参数(金额范围、超时时间等) -4. 在 **服务商管理** 中添加至少一个服务商实例 -5. 用户即可在前端页面进行充值 - ---- - -## 系统设置 - -在管理后台 **设置 → 支付设置** 中配置以下参数: - -### 基本设置 - -| 设置项 | 说明 | 默认值 | -|--------|------|--------| -| **启用支付** | 启用或禁用支付系统 | 关闭 | -| **商品名前缀** | 支付页面显示的商品名前缀 | - | -| **商品名后缀** | 商品名后缀(如"元") | - | -| **最低金额** | 单笔最低充值金额 | 1 | -| **最高金额** | 单笔最高充值金额(留空表示不限制) | - | -| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - | -| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 5 | -| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 | -| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 | - -### 负载均衡策略 - -| 策略 | 说明 | -|------|------| -| **轮询(round-robin)** | 按顺序轮流分配到各服务商实例 | -| **最少金额(least-amount)** | 优先分配到当日累计金额最少的实例 | - -### 取消频率限制 - -防止用户频繁创建并取消订单: - -| 设置项 | 说明 | -|--------|------| -| **启用限制** | 开关 | -| **窗口模式** | 滚动窗口 / 固定窗口 | -| **时间窗口** | 窗口长度 | -| **窗口单位** | 分钟 / 小时 | -| **最大次数** | 窗口内允许的最大取消次数 | - -### 帮助信息 - -| 设置项 | 说明 | -|--------|------| -| **帮助图片** | 充值页面显示的客服二维码等图片(支持上传) | -| **帮助文本** | 充值页面显示的说明文字 | - ---- - -## 服务商配置 - -每种服务商需要不同的凭证和参数。在 **服务商管理 → 添加服务商** 中选择类型后填写。 - -> **回调地址自动生成**:添加服务商时,异步回调地址(Notify URL)和同步跳转地址(Return URL)由系统根据你的站点域名自动拼接,无需手动填写。管理员只需确认域名正确即可。 - -### EasyPay(易支付) - -兼容任何 EasyPay 协议的支付服务商。 - -| 参数 | 说明 | 必填 | -|------|------|------| -| **商户 ID(PID)** | EasyPay 商户 ID | 是 | -| **商户密钥(PKey)** | EasyPay 商户密钥 | 是 | -| **API 地址** | EasyPay API 基础地址 | 是 | -| **支付宝通道 ID** | 指定支付宝通道(可选) | 否 | -| **微信通道 ID** | 指定微信通道(可选) | 否 | - -### 支付宝官方 - -直接对接支付宝开放平台,支持 PC 页面支付和 H5 手机网站支付。 - -| 参数 | 说明 | 必填 | -|------|------|------| -| **AppID** | 支付宝应用 AppID | 是 | -| **应用私钥** | RSA2 应用私钥 | 是 | -| **支付宝公钥** | 支付宝公钥 | 是 | - -### 微信官方 - -直接对接微信支付 APIv3,支持 Native 扫码支付和 H5 支付。 - -| 参数 | 说明 | 必填 | -|------|------|------| -| **AppID** | 微信支付 AppID | 是 | -| **商户号(MchID)** | 微信支付商户号 | 是 | -| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 | -| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 | -| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 | -| **微信支付公钥 ID** | 微信支付公钥 ID | 否 | -| **商户证书序列号** | 商户证书序列号 | 否 | - -### Stripe - -国际支付平台,支持多种支付方式和币种。 - -| 参数 | 说明 | 必填 | -|------|------|------| -| **Secret Key** | Stripe 密钥(`sk_live_...` 或 `sk_test_...`) | 是 | -| **Publishable Key** | Stripe 可公开密钥(`pk_live_...` 或 `pk_test_...`) | 是 | -| **Webhook Secret** | Stripe Webhook 签名密钥(`whsec_...`) | 是 | - ---- - -## 服务商实例管理 - -同一种服务商可以创建**多个实例**,实现负载均衡和风控: - -- **多实例负载均衡** — 按轮询或最少金额策略分流订单 -- **独立限额** — 每个实例可独立配置单笔最小/最大金额和每日限额 -- **独立启停** — 可单独启用/禁用某个实例,不影响其他实例 -- **退款控制** — 每个实例可单独开启或关闭退款功能 -- **支付方式** — 每个实例可选择支持的支付方式子集 -- **排序** — 拖拽调整实例顺序 - -### 实例限额配置 - -每个实例支持以下限额: - -| 限额项 | 说明 | -|--------|------| -| **单笔最小金额** | 该实例接受的最小订单金额 | -| **单笔最大金额** | 该实例接受的最大订单金额 | -| **每日限额** | 该实例每日累计交易上限 | - -> 负载均衡时,系统会自动跳过超出限额的实例。 - ---- - -## Webhook 配置 - -支付回调是支付系统的核心环节,必须正确配置: - -### 回调地址格式 - -添加服务商时,系统会自动根据站点域名拼接回调地址,格式如下: - -| 服务商 | 回调路径 | -|--------|---------| -| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` | -| **支付宝官方** | `https://your-domain.com/api/v1/payment/webhook/alipay` | -| **微信官方** | `https://your-domain.com/api/v1/payment/webhook/wxpay` | -| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` | - -> 将 `your-domain.com` 替换为你的实际域名。EasyPay / 支付宝 / 微信的回调地址在添加服务商时自动填入,无需手动配置。 - -### Stripe Webhook 设置 - -1. 登录 [Stripe Dashboard](https://dashboard.stripe.com/) -2. 进入 **Developers → Webhooks** -3. 添加端点,填写回调地址 -4. 订阅事件:`payment_intent.succeeded`、`payment_intent.payment_failed` -5. 将生成的 Webhook Secret(`whsec_...`)填入服务商配置 - -### 注意事项 - -- 回调地址必须是 **HTTPS**(Stripe 强制要求,其他服务商强烈推荐) -- 确保服务器防火墙允许支付平台的回调请求 -- 系统会自动进行签名验证,防止伪造回调 -- 支付成功后自动完成余额充值,无需人工干预 - ---- - -## 支付流程 - -``` -用户选择充值金额和支付方式 - │ - ▼ - 创建订单 (PENDING) - ├─ 校验金额范围、待支付订单数、每日限额 - ├─ 负载均衡选择服务商实例 - └─ 调用服务商获取支付信息 - │ - ▼ - 用户完成支付 - ├─ EasyPay → 扫码 / H5 跳转 - ├─ 支付宝官方 → PC 页面支付 / H5 手机网站支付 - ├─ 微信官方 → Native 扫码 / H5 支付 - └─ Stripe → Payment Element(银行卡/支付宝/微信等) - │ - ▼ - 支付回调验签 → 订单 PAID - │ - ▼ - 自动充值到用户余额 → 订单 COMPLETED -``` - -### 订单状态说明 - -| 状态 | 说明 | -|------|------| -| `PENDING` | 待支付,等待用户完成支付 | -| `PAID` | 已支付,等待充值到账 | -| `COMPLETED` | 已完成,余额已到账 | -| `EXPIRED` | 已过期,超时未支付 | -| `CANCELLED` | 已取消,用户主动取消 | -| `FAILED` | 充值失败,可管理员重试 | -| `REFUND_REQUESTED` | 已申请退款 | -| `REFUNDING` | 退款处理中 | -| `REFUNDED` | 已退款 | - -### 超时与兜底 - -- 订单超时后,后台任务会先查询上游支付状态再标记过期 -- 如果用户实际已支付但回调延迟,系统会通过查询补单 -- 后台任务每 60 秒执行一次超时检查 - ---- - -## 从 Sub2ApiPay 迁移 - -如果你之前使用 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 作为外部支付系统,现在可以迁移到内置支付: - -### 主要差异 - -| 对比项 | Sub2ApiPay | 内置支付 | -|--------|-----------|---------| -| 部署方式 | 独立服务(Next.js + PostgreSQL) | 内置于 Sub2API,无需额外部署 | -| 支付方式 | EasyPay、支付宝、微信、Stripe | 相同 | -| 配置方式 | 环境变量 + 独立管理后台 | Sub2API 管理后台内统一配置 | -| 充值对接 | 通过 Admin API 回调 | 内部直接处理,更可靠 | -| 订阅套餐 | 支持 | 暂不支持(计划中) | -| 订单管理 | 独立管理界面 | 集成在 Sub2API 管理后台 | - -### 迁移步骤 - -1. 在 Sub2API 管理后台启用支付并配置服务商(使用相同的支付凭证) -2. 更新 Webhook 回调地址为 Sub2API 的回调地址 -3. 确认新订单通过内置支付正常处理 -4. 停用 Sub2ApiPay 服务 - -> **注意**:Sub2ApiPay 中的历史订单数据不会自动迁移。建议保留 Sub2ApiPay 一段时间以便查询历史记录。 diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..f95332fb971801be109533bef8138b2d7e686124 --- /dev/null +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -0,0 +1,192 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const post = vi.fn() + +vi.mock('@/api/client', () => ({ + apiClient: { + post + } +})) + +describe('oauth adoption auth api', () => { + beforeEach(() => { + post.mockReset() + post.mockResolvedValue({ data: {} }) + localStorage.clear() + document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/' + }) + + it('posts adoption decisions when exchanging pending oauth completion', async () => { + const { exchangePendingOAuthCompletion } = await import('@/api/auth') + + await exchangePendingOAuthCompletion({ + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts bind-login decisions when finalizing pending oauth bind flow', async () => { + const { completePendingOAuthBindLogin } = await import('@/api/auth') + + await completePendingOAuthBindLogin({ + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts linuxdo invitation completion with adoption decisions', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts linuxdo create-account completion with adoption decisions', async () => { + const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth') + + await createPendingLinuxDoOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts oidc invitation completion with adoption decisions', async () => { + const { completeOIDCOAuthRegistration } = await import('@/api/auth') + + await completeOIDCOAuthRegistration('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts oidc create-account completion with adoption decisions', async () => { + const { createPendingOIDCOAuthAccount } = await import('@/api/auth') + + await createPendingOIDCOAuthAccount('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts wechat invitation completion with adoption decisions', async () => { + const { completeWeChatOAuthRegistration } = await import('@/api/auth') + + await completeWeChatOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: true + }) + }) + + it('posts wechat create-account completion with adoption decisions', async () => { + const { createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: false + }) + }) + + it('classifies oauth completion results as login or bind', async () => { + const { getOAuthCompletionKind } = await import('@/api/auth') + + expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login') + expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind') + }) + + it('provides bind-login utility helpers for invitation and suggested profile states', async () => { + const { + getPendingOAuthBindLoginKind, + hasPendingOAuthSuggestedProfile, + isPendingOAuthCreateAccountRequired + } = await import('@/api/auth') + + expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login') + expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind') + expect( + isPendingOAuthCreateAccountRequired({ + error: 'invitation_required' + }) + ).toBe(true) + expect( + isPendingOAuthCreateAccountRequired({ + error: 'other' + }) + ).toBe(false) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_display_name: 'OAuth Nick' + }) + ).toBe(true) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_avatar_url: 'https://cdn.example/avatar.png' + }) + ).toBe(true) + expect(hasPendingOAuthSuggestedProfile({})).toBe(false) + }) + + it('prepares an oauth bind access token cookie before redirect binding', async () => { + localStorage.setItem('auth_token', 'access-token-value') + const setCookie = vi.fn() + Object.defineProperty(document, 'cookie', { + configurable: true, + get: () => '', + set: setCookie + }) + + const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth') + + prepareOAuthBindAccessTokenCookie() + + expect(setCookie).toHaveBeenCalledTimes(1) + expect(setCookie.mock.calls[0]?.[0]).toContain('oauth_bind_access_token=access-token-value') + }) +}) diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..3006484ef8a9ecd7951d6f33a592813075d047a0 --- /dev/null +++ b/frontend/src/api/__tests__/payment.spec.ts @@ -0,0 +1,36 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { get, post } = vi.hoisted(() => ({ + get: vi.fn(), + post: vi.fn(), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + get, + post, + }, +})) + +import { paymentAPI } from '@/api/payment' + +describe('payment api', () => { + beforeEach(() => { + get.mockReset() + post.mockReset() + get.mockResolvedValue({ data: {} }) + post.mockResolvedValue({ data: {} }) + }) + + it('does not expose anonymous public out_trade_no verification', () => { + expect(Object.prototype.hasOwnProperty.call(paymentAPI, 'verifyOrderPublic')).toBe(false) + }) + + it('keeps signed public resume-token resolve endpoint', async () => { + await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123') + + expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', { + resume_token: 'resume-token-123', + }) + }) +}) diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..10f6247ab7eca93114463b0bacbf66add9d80d03 --- /dev/null +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -0,0 +1,131 @@ +import { describe, expect, it } from "vitest"; + +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + type UpdateSettingsRequest, +} from "@/api/admin/settings"; + +describe("admin settings auth source defaults helpers", () => { + it("builds auth source defaults state from flat settings fields", () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_balance: 9.5, + auth_source_default_email_concurrency: 3, + auth_source_default_email_subscriptions: [ + { group_id: 1, validity_days: 30 }, + ], + auth_source_default_email_grant_on_signup: false, + auth_source_default_email_grant_on_first_bind: true, + auth_source_default_linuxdo_balance: 6, + auth_source_default_linuxdo_concurrency: 8, + auth_source_default_linuxdo_subscriptions: [ + { group_id: 2, validity_days: 60 }, + ], + auth_source_default_linuxdo_grant_on_signup: true, + auth_source_default_linuxdo_grant_on_first_bind: false, + }); + + expect(state.email).toEqual({ + balance: 9.5, + concurrency: 3, + subscriptions: [{ group_id: 1, validity_days: 30 }], + grant_on_signup: false, + grant_on_first_bind: true, + }); + expect(state.linuxdo).toEqual({ + balance: 6, + concurrency: 8, + subscriptions: [{ group_id: 2, validity_days: 60 }], + grant_on_signup: true, + grant_on_first_bind: false, + }); + expect(state.oidc).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }); + expect(state.wechat).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }); + }); + + it("defaults grant-on-signup to disabled when settings are missing", () => { + const state = buildAuthSourceDefaultsState({}); + + expect(state.email.grant_on_signup).toBe(false); + expect(state.linuxdo.grant_on_signup).toBe(false); + expect(state.oidc.grant_on_signup).toBe(false); + expect(state.wechat.grant_on_signup).toBe(false); + }); + + it("appends auth source defaults back onto update payload", () => { + const payload: UpdateSettingsRequest = { + site_name: "Sub2API", + }; + + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 1.25, + concurrency: 2, + subscriptions: [{ group_id: 3, validity_days: 7 }], + grant_on_signup: true, + grant_on_first_bind: false, + }, + linuxdo: { + balance: 0, + concurrency: 6, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: true, + }, + oidc: { + balance: 4, + concurrency: 9, + subscriptions: [{ group_id: 9, validity_days: 90 }], + grant_on_signup: true, + grant_on_first_bind: true, + }, + wechat: { + balance: 2, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }, + }); + + expect(payload).toMatchObject({ + site_name: "Sub2API", + auth_source_default_email_balance: 1.25, + auth_source_default_email_concurrency: 2, + auth_source_default_email_subscriptions: [ + { group_id: 3, validity_days: 7 }, + ], + auth_source_default_email_grant_on_signup: true, + auth_source_default_email_grant_on_first_bind: false, + auth_source_default_linuxdo_balance: 0, + auth_source_default_linuxdo_concurrency: 6, + auth_source_default_linuxdo_subscriptions: [], + auth_source_default_linuxdo_grant_on_signup: false, + auth_source_default_linuxdo_grant_on_first_bind: true, + auth_source_default_oidc_balance: 4, + auth_source_default_oidc_concurrency: 9, + auth_source_default_oidc_subscriptions: [ + { group_id: 9, validity_days: 90 }, + ], + auth_source_default_oidc_grant_on_signup: true, + auth_source_default_oidc_grant_on_first_bind: true, + auth_source_default_wechat_balance: 2, + auth_source_default_wechat_concurrency: 5, + auth_source_default_wechat_subscriptions: [], + auth_source_default_wechat_grant_on_signup: false, + auth_source_default_wechat_grant_on_first_bind: false, + }); + }); +}); diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ad355afee372d5c0ee78686c370153efb6e6b407 --- /dev/null +++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts @@ -0,0 +1,63 @@ +import { describe, expect, it } from 'vitest' + +import { + getPaymentVisibleMethodSourceOptions, + normalizePaymentVisibleMethodSource, +} from '@/api/admin/settings' + +describe('admin settings payment visible method helpers', () => { + it('normalizes aliases into canonical source keys per visible method', () => { + expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay') + expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay') + expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay') + + expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay') + expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay') + expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay') + }) + + it('rejects unknown or cross-method source values', () => { + expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('') + expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('') + expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('') + expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('') + }) + + it('exposes method-scoped source options instead of arbitrary strings', () => { + expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([ + { + value: '', + labelZh: '未配置', + labelEn: 'Not configured', + }, + { + value: 'official_alipay', + labelZh: '支付宝官方', + labelEn: 'Official Alipay', + }, + { + value: 'easypay_alipay', + labelZh: '易支付支付宝', + labelEn: 'EasyPay Alipay', + }, + ]) + + expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([ + { + value: '', + labelZh: '未配置', + labelEn: 'Not configured', + }, + { + value: 'official_wxpay', + labelZh: '微信官方', + labelEn: 'Official WeChat Pay', + }, + { + value: 'easypay_wxpay', + labelZh: '易支付微信', + labelEn: 'EasyPay WeChat Pay', + }, + ]) + }) +}) diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..eccb72149b4a28c4f63d5a8fe459a72eacc4e081 --- /dev/null +++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from "vitest"; + +import { + defaultWeChatConnectScopesForMode, + normalizeWeChatConnectMode, +} from "@/api/admin/settings"; + +describe("admin settings wechat connect helpers", () => { + it("normalizes legacy or noisy mode values to the backend contract", () => { + expect(normalizeWeChatConnectMode("OPEN")).toBe("open"); + expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open"); + expect(normalizeWeChatConnectMode("mp")).toBe("mp"); + expect(normalizeWeChatConnectMode("official_account")).toBe("mp"); + expect(normalizeWeChatConnectMode("unknown")).toBe("open"); + }); + + it("maps each mode to the backend default scopes", () => { + expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login"); + expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo"); + }); +}); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 1e4a3053092e5870d8004ee433ffe86f80581ed2..0403b0f3fefea210e8ca3d6edd8efe3a1c4f5014 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -3,12 +3,293 @@ * Handles system settings management for administrators */ -import { apiClient } from '../client' -import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types' +import { apiClient } from "../client"; +import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types"; export interface DefaultSubscriptionSetting { - group_id: number - validity_days: number + group_id: number; + validity_days: number; +} + +export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat"; + +export interface AuthSourceDefaultsValue { + balance: number; + concurrency: number; + subscriptions: DefaultSubscriptionSetting[]; + grant_on_signup: boolean; + grant_on_first_bind: boolean; +} + +export type AuthSourceDefaultsState = Record< + AuthSourceType, + AuthSourceDefaultsValue +>; +export type PaymentVisibleMethod = "alipay" | "wxpay"; +export type PaymentVisibleMethodSource = + | "" + | "official_alipay" + | "easypay_alipay" + | "official_wxpay" + | "easypay_wxpay"; +export type WeChatConnectMode = "open" | "mp" | "mobile"; + +export interface PaymentVisibleMethodSourceOption { + value: PaymentVisibleMethodSource; + labelZh: string; + labelEn: string; +} + +export interface WeChatConnectModeOption { + value: WeChatConnectMode; + labelZh: string; + labelEn: string; +} + +const AUTH_SOURCE_TYPES: AuthSourceType[] = [ + "email", + "linuxdo", + "oidc", + "wechat", +]; +const AUTH_SOURCE_DEFAULT_BALANCE = 0; +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; +const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record< + PaymentVisibleMethod, + PaymentVisibleMethodSourceOption[] +> = { + alipay: [ + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_alipay", + labelZh: "支付宝官方", + labelEn: "Official Alipay", + }, + { + value: "easypay_alipay", + labelZh: "易支付支付宝", + labelEn: "EasyPay Alipay", + }, + ], + wxpay: [ + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_wxpay", + labelZh: "微信官方", + labelEn: "Official WeChat Pay", + }, + { + value: "easypay_wxpay", + labelZh: "易支付微信", + labelEn: "EasyPay WeChat Pay", + }, + ], +}; +const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record< + PaymentVisibleMethod, + Record +> = { + alipay: { + official_alipay: "official_alipay", + alipay: "official_alipay", + alipay_direct: "official_alipay", + official: "official_alipay", + easypay_alipay: "easypay_alipay", + easypay: "easypay_alipay", + }, + wxpay: { + official_wxpay: "official_wxpay", + wxpay: "official_wxpay", + wxpay_direct: "official_wxpay", + wechat: "official_wxpay", + official: "official_wxpay", + easypay_wxpay: "easypay_wxpay", + easypay: "easypay_wxpay", + }, +}; +const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [ + { value: "open", labelZh: "PC 应用", labelEn: "PC App" }, + { + value: "mp", + labelZh: "公众号", + labelEn: "Official Account", + }, + { + value: "mobile", + labelZh: "移动应用", + labelEn: "Mobile App", + }, +]; +const WECHAT_CONNECT_MODE_ALIASES: Record = { + open: "open", + open_platform: "open", + official: "open", + wx_open: "open", + mp: "mp", + official_account: "mp", + wechat_mp: "mp", + mini_program: "mp", + mobile: "mobile", + mobile_app: "mobile", + native_app: "mobile", +}; + +export function normalizeDefaultSubscriptionSettings( + subscriptions: DefaultSubscriptionSetting[] | null | undefined, +): DefaultSubscriptionSetting[] { + if (!Array.isArray(subscriptions)) return []; + + return subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: Math.floor(item.group_id), + validity_days: Math.min( + 36500, + Math.max(1, Math.floor(item.validity_days)), + ), + })); +} + +export function buildAuthSourceDefaultsState( + settings: Partial, +): AuthSourceDefaultsState { + const raw = settings as Record; + + return AUTH_SOURCE_TYPES.reduce((acc, source) => { + const subscriptions = raw[`auth_source_default_${source}_subscriptions`]; + acc[source] = { + balance: Number( + raw[`auth_source_default_${source}_balance`] ?? + AUTH_SOURCE_DEFAULT_BALANCE, + ), + concurrency: Math.max( + 1, + Number( + raw[`auth_source_default_${source}_concurrency`] ?? + AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ), + subscriptions: normalizeDefaultSubscriptionSettings( + Array.isArray(subscriptions) + ? (subscriptions as DefaultSubscriptionSetting[]) + : [], + ), + grant_on_signup: + raw[`auth_source_default_${source}_grant_on_signup`] === true, + grant_on_first_bind: + raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + }; + return acc; + }, {} as AuthSourceDefaultsState); +} + +export function appendAuthSourceDefaultsToUpdateRequest( + payload: UpdateSettingsRequest, + authSourceDefaults: AuthSourceDefaultsState, +): UpdateSettingsRequest { + const target = payload as Record; + + for (const source of AUTH_SOURCE_TYPES) { + const current = authSourceDefaults[source]; + target[`auth_source_default_${source}_balance`] = + Number(current.balance) || 0; + target[`auth_source_default_${source}_concurrency`] = Math.max( + 1, + Math.floor( + Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ); + target[`auth_source_default_${source}_subscriptions`] = + normalizeDefaultSubscriptionSettings(current.subscriptions); + target[`auth_source_default_${source}_grant_on_signup`] = + current.grant_on_signup; + target[`auth_source_default_${source}_grant_on_first_bind`] = + current.grant_on_first_bind; + } + + return payload; +} + +export function getPaymentVisibleMethodSourceOptions( + method: PaymentVisibleMethod, +): PaymentVisibleMethodSourceOption[] { + return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]; +} + +export function normalizePaymentVisibleMethodSource( + method: PaymentVisibleMethod, + source: unknown, +): PaymentVisibleMethodSource { + if (typeof source !== "string") return ""; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return ""; + + return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ""; +} + +export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] { + return WECHAT_CONNECT_MODE_OPTIONS; +} + +export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode { + if (typeof source !== "string") return "open"; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return "open"; + + return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open"; +} + +export function defaultWeChatConnectScopesForMode(mode: unknown): string { + switch (normalizeWeChatConnectMode(mode)) { + case "mp": + return "snsapi_userinfo"; + case "mobile": + return ""; + default: + return "snsapi_login"; + } +} + +export function resolveWeChatConnectModeCapabilities( + openEnabled: unknown, + mpEnabled: unknown, + mobileEnabled: unknown, + legacyMode: unknown, +): { openEnabled: boolean; mpEnabled: boolean; mobileEnabled: boolean } { + if ( + typeof openEnabled === "boolean" || + typeof mpEnabled === "boolean" || + typeof mobileEnabled === "boolean" + ) { + return { + openEnabled: openEnabled === true, + mpEnabled: mpEnabled === true, + mobileEnabled: mobileEnabled === true, + }; + } + + switch (normalizeWeChatConnectMode(legacyMode)) { + case "mp": + return { openEnabled: false, mpEnabled: true, mobileEnabled: false }; + case "mobile": + return { openEnabled: false, mpEnabled: false, mobileEnabled: true }; + default: + return { openEnabled: true, mpEnabled: false, mobileEnabled: false }; + } +} + +export function deriveWeChatConnectStoredMode( + openEnabled: boolean, + mpEnabled: boolean, + mobileEnabled: boolean, + legacyMode: unknown, +): WeChatConnectMode { + if (mpEnabled) return "mp"; + if (mobileEnabled) return "mobile"; + if (openEnabled) return "open"; + return normalizeWeChatConnectMode(legacyMode); } /** @@ -16,241 +297,327 @@ export interface DefaultSubscriptionSetting { */ export interface SystemSettings { // Registration settings - registration_enabled: boolean - email_verify_enabled: boolean - registration_email_suffix_whitelist: string[] - promo_code_enabled: boolean - password_reset_enabled: boolean - frontend_url: string - invitation_code_enabled: boolean - totp_enabled: boolean // TOTP 双因素认证 - totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 + registration_enabled: boolean; + email_verify_enabled: boolean; + registration_email_suffix_whitelist: string[]; + promo_code_enabled: boolean; + password_reset_enabled: boolean; + frontend_url: string; + invitation_code_enabled: boolean; + totp_enabled: boolean; // TOTP 双因素认证 + totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings - default_balance: number - default_concurrency: number - default_subscriptions: DefaultSubscriptionSetting[] + default_balance: number; + default_concurrency: number; + default_subscriptions: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; // OEM settings - site_name: string - site_logo: string - site_subtitle: string - api_base_url: string - contact_info: string - doc_url: string - home_content: string - hide_ccs_import_button: boolean - table_default_page_size: number - table_page_size_options: number[] - backend_mode_enabled: boolean - custom_menu_items: CustomMenuItem[] - custom_endpoints: CustomEndpoint[] + site_name: string; + site_logo: string; + site_subtitle: string; + api_base_url: string; + contact_info: string; + doc_url: string; + home_content: string; + hide_ccs_import_button: boolean; + table_default_page_size: number; + table_page_size_options: number[]; + backend_mode_enabled: boolean; + custom_menu_items: CustomMenuItem[]; + custom_endpoints: CustomEndpoint[]; // SMTP settings - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password_configured: boolean - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password_configured: boolean; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; // Cloudflare Turnstile settings - turnstile_enabled: boolean - turnstile_site_key: string - turnstile_secret_key_configured: boolean + turnstile_enabled: boolean; + turnstile_site_key: string; + turnstile_secret_key_configured: boolean; // LinuxDo Connect OAuth settings - linuxdo_connect_enabled: boolean - linuxdo_connect_client_id: string - linuxdo_connect_client_secret_configured: boolean - linuxdo_connect_redirect_url: string + linuxdo_connect_enabled: boolean; + linuxdo_connect_client_id: string; + linuxdo_connect_client_secret_configured: boolean; + linuxdo_connect_redirect_url: string; + + // WeChat Connect OAuth settings + wechat_connect_enabled: boolean; + wechat_connect_app_id: string; + wechat_connect_app_secret_configured: boolean; + wechat_connect_open_app_id?: string; + wechat_connect_open_app_secret_configured?: boolean; + wechat_connect_mp_app_id?: string; + wechat_connect_mp_app_secret_configured?: boolean; + wechat_connect_mobile_app_id?: string; + wechat_connect_mobile_app_secret_configured?: boolean; + wechat_connect_open_enabled?: boolean; + wechat_connect_mp_enabled?: boolean; + wechat_connect_mobile_enabled?: boolean; + wechat_connect_mode: string; + wechat_connect_scopes: string; + wechat_connect_redirect_url: string; + wechat_connect_frontend_redirect_url: string; // Generic OIDC OAuth settings - oidc_connect_enabled: boolean - oidc_connect_provider_name: string - oidc_connect_client_id: string - oidc_connect_client_secret_configured: boolean - oidc_connect_issuer_url: string - oidc_connect_discovery_url: string - oidc_connect_authorize_url: string - oidc_connect_token_url: string - oidc_connect_userinfo_url: string - oidc_connect_jwks_url: string - oidc_connect_scopes: string - oidc_connect_redirect_url: string - oidc_connect_frontend_redirect_url: string - oidc_connect_token_auth_method: string - oidc_connect_use_pkce: boolean - oidc_connect_validate_id_token: boolean - oidc_connect_allowed_signing_algs: string - oidc_connect_clock_skew_seconds: number - oidc_connect_require_email_verified: boolean - oidc_connect_userinfo_email_path: string - oidc_connect_userinfo_id_path: string - oidc_connect_userinfo_username_path: string + oidc_connect_enabled: boolean; + oidc_connect_provider_name: string; + oidc_connect_client_id: string; + oidc_connect_client_secret_configured: boolean; + oidc_connect_issuer_url: string; + oidc_connect_discovery_url: string; + oidc_connect_authorize_url: string; + oidc_connect_token_url: string; + oidc_connect_userinfo_url: string; + oidc_connect_jwks_url: string; + oidc_connect_scopes: string; + oidc_connect_redirect_url: string; + oidc_connect_frontend_redirect_url: string; + oidc_connect_token_auth_method: string; + oidc_connect_use_pkce: boolean; + oidc_connect_validate_id_token: boolean; + oidc_connect_allowed_signing_algs: string; + oidc_connect_clock_skew_seconds: number; + oidc_connect_require_email_verified: boolean; + oidc_connect_userinfo_email_path: string; + oidc_connect_userinfo_id_path: string; + oidc_connect_userinfo_username_path: string; // Model fallback configuration - enable_model_fallback: boolean - fallback_model_anthropic: string - fallback_model_openai: string - fallback_model_gemini: string - fallback_model_antigravity: string + enable_model_fallback: boolean; + fallback_model_anthropic: string; + fallback_model_openai: string; + fallback_model_gemini: string; + fallback_model_antigravity: string; // Identity patch configuration (Claude -> Gemini) - enable_identity_patch: boolean - identity_patch_prompt: string + enable_identity_patch: boolean; + identity_patch_prompt: string; // Ops Monitoring (vNext) - ops_monitoring_enabled: boolean - ops_realtime_monitoring_enabled: boolean - ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds: number + ops_monitoring_enabled: boolean; + ops_realtime_monitoring_enabled: boolean; + ops_query_mode_default: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds: number; // Claude Code version check - min_claude_code_version: string - max_claude_code_version: string + min_claude_code_version: string; + max_claude_code_version: string; // 分组隔离 - allow_ungrouped_key_scheduling: boolean + allow_ungrouped_key_scheduling: boolean; // Gateway forwarding behavior - enable_fingerprint_unification: boolean - enable_metadata_passthrough: boolean - enable_cch_signing: boolean - web_search_emulation_enabled?: boolean + enable_fingerprint_unification: boolean; + enable_metadata_passthrough: boolean; + enable_cch_signing: boolean; + web_search_emulation_enabled?: boolean; // Payment configuration - payment_enabled: boolean - payment_min_amount: number - payment_max_amount: number - payment_daily_limit: number - payment_order_timeout_minutes: number - payment_max_pending_orders: number - payment_enabled_types: string[] - payment_balance_disabled: boolean - payment_balance_recharge_multiplier: number - payment_recharge_fee_rate: number - payment_load_balance_strategy: string - payment_product_name_prefix: string - payment_product_name_suffix: string - payment_help_image_url: string - payment_help_text: string - payment_cancel_rate_limit_enabled: boolean - payment_cancel_rate_limit_max: number - payment_cancel_rate_limit_window: number - payment_cancel_rate_limit_unit: string - payment_cancel_rate_limit_window_mode: string + payment_enabled: boolean; + payment_min_amount: number; + payment_max_amount: number; + payment_daily_limit: number; + payment_order_timeout_minutes: number; + payment_max_pending_orders: number; + payment_enabled_types: string[]; + payment_balance_disabled: boolean; + payment_balance_recharge_multiplier: number; + payment_recharge_fee_rate: number; + payment_load_balance_strategy: string; + payment_product_name_prefix: string; + payment_product_name_suffix: string; + payment_help_image_url: string; + payment_help_text: string; + payment_cancel_rate_limit_enabled: boolean; + payment_cancel_rate_limit_max: number; + payment_cancel_rate_limit_window: number; + payment_cancel_rate_limit_unit: string; + payment_cancel_rate_limit_window_mode: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled: boolean - balance_low_notify_threshold: number - balance_low_notify_recharge_url: string - account_quota_notify_enabled: boolean - account_quota_notify_emails: NotifyEmailEntry[] + balance_low_notify_enabled: boolean; + balance_low_notify_threshold: number; + balance_low_notify_recharge_url: string; + account_quota_notify_enabled: boolean; + account_quota_notify_emails: NotifyEmailEntry[]; } export interface UpdateSettingsRequest { - registration_enabled?: boolean - email_verify_enabled?: boolean - registration_email_suffix_whitelist?: string[] - promo_code_enabled?: boolean - password_reset_enabled?: boolean - frontend_url?: string - invitation_code_enabled?: boolean - totp_enabled?: boolean // TOTP 双因素认证 - default_balance?: number - default_concurrency?: number - default_subscriptions?: DefaultSubscriptionSetting[] - site_name?: string - site_logo?: string - site_subtitle?: string - api_base_url?: string - contact_info?: string - doc_url?: string - home_content?: string - hide_ccs_import_button?: boolean - table_default_page_size?: number - table_page_size_options?: number[] - backend_mode_enabled?: boolean - custom_menu_items?: CustomMenuItem[] - custom_endpoints?: CustomEndpoint[] - smtp_host?: string - smtp_port?: number - smtp_username?: string - smtp_password?: string - smtp_from_email?: string - smtp_from_name?: string - smtp_use_tls?: boolean - turnstile_enabled?: boolean - turnstile_site_key?: string - turnstile_secret_key?: string - linuxdo_connect_enabled?: boolean - linuxdo_connect_client_id?: string - linuxdo_connect_client_secret?: string - linuxdo_connect_redirect_url?: string - oidc_connect_enabled?: boolean - oidc_connect_provider_name?: string - oidc_connect_client_id?: string - oidc_connect_client_secret?: string - oidc_connect_issuer_url?: string - oidc_connect_discovery_url?: string - oidc_connect_authorize_url?: string - oidc_connect_token_url?: string - oidc_connect_userinfo_url?: string - oidc_connect_jwks_url?: string - oidc_connect_scopes?: string - oidc_connect_redirect_url?: string - oidc_connect_frontend_redirect_url?: string - oidc_connect_token_auth_method?: string - oidc_connect_use_pkce?: boolean - oidc_connect_validate_id_token?: boolean - oidc_connect_allowed_signing_algs?: string - oidc_connect_clock_skew_seconds?: number - oidc_connect_require_email_verified?: boolean - oidc_connect_userinfo_email_path?: string - oidc_connect_userinfo_id_path?: string - oidc_connect_userinfo_username_path?: string - enable_model_fallback?: boolean - fallback_model_anthropic?: string - fallback_model_openai?: string - fallback_model_gemini?: string - fallback_model_antigravity?: string - enable_identity_patch?: boolean - identity_patch_prompt?: string - ops_monitoring_enabled?: boolean - ops_realtime_monitoring_enabled?: boolean - ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds?: number - min_claude_code_version?: string - max_claude_code_version?: string - allow_ungrouped_key_scheduling?: boolean - enable_fingerprint_unification?: boolean - enable_metadata_passthrough?: boolean - enable_cch_signing?: boolean + registration_enabled?: boolean; + email_verify_enabled?: boolean; + registration_email_suffix_whitelist?: string[]; + promo_code_enabled?: boolean; + password_reset_enabled?: boolean; + frontend_url?: string; + invitation_code_enabled?: boolean; + totp_enabled?: boolean; // TOTP 双因素认证 + default_balance?: number; + default_concurrency?: number; + default_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; + site_name?: string; + site_logo?: string; + site_subtitle?: string; + api_base_url?: string; + contact_info?: string; + doc_url?: string; + home_content?: string; + hide_ccs_import_button?: boolean; + table_default_page_size?: number; + table_page_size_options?: number[]; + backend_mode_enabled?: boolean; + custom_menu_items?: CustomMenuItem[]; + custom_endpoints?: CustomEndpoint[]; + smtp_host?: string; + smtp_port?: number; + smtp_username?: string; + smtp_password?: string; + smtp_from_email?: string; + smtp_from_name?: string; + smtp_use_tls?: boolean; + turnstile_enabled?: boolean; + turnstile_site_key?: string; + turnstile_secret_key?: string; + linuxdo_connect_enabled?: boolean; + linuxdo_connect_client_id?: string; + linuxdo_connect_client_secret?: string; + linuxdo_connect_redirect_url?: string; + wechat_connect_enabled?: boolean; + wechat_connect_app_id?: string; + wechat_connect_app_secret?: string; + wechat_connect_open_app_id?: string; + wechat_connect_open_app_secret?: string; + wechat_connect_mp_app_id?: string; + wechat_connect_mp_app_secret?: string; + wechat_connect_mobile_app_id?: string; + wechat_connect_mobile_app_secret?: string; + wechat_connect_open_enabled?: boolean; + wechat_connect_mp_enabled?: boolean; + wechat_connect_mobile_enabled?: boolean; + wechat_connect_mode?: string; + wechat_connect_scopes?: string; + wechat_connect_redirect_url?: string; + wechat_connect_frontend_redirect_url?: string; + oidc_connect_enabled?: boolean; + oidc_connect_provider_name?: string; + oidc_connect_client_id?: string; + oidc_connect_client_secret?: string; + oidc_connect_issuer_url?: string; + oidc_connect_discovery_url?: string; + oidc_connect_authorize_url?: string; + oidc_connect_token_url?: string; + oidc_connect_userinfo_url?: string; + oidc_connect_jwks_url?: string; + oidc_connect_scopes?: string; + oidc_connect_redirect_url?: string; + oidc_connect_frontend_redirect_url?: string; + oidc_connect_token_auth_method?: string; + oidc_connect_use_pkce?: boolean; + oidc_connect_validate_id_token?: boolean; + oidc_connect_allowed_signing_algs?: string; + oidc_connect_clock_skew_seconds?: number; + oidc_connect_require_email_verified?: boolean; + oidc_connect_userinfo_email_path?: string; + oidc_connect_userinfo_id_path?: string; + oidc_connect_userinfo_username_path?: string; + enable_model_fallback?: boolean; + fallback_model_anthropic?: string; + fallback_model_openai?: string; + fallback_model_gemini?: string; + fallback_model_antigravity?: string; + enable_identity_patch?: boolean; + identity_patch_prompt?: string; + ops_monitoring_enabled?: boolean; + ops_realtime_monitoring_enabled?: boolean; + ops_query_mode_default?: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds?: number; + min_claude_code_version?: string; + max_claude_code_version?: string; + allow_ungrouped_key_scheduling?: boolean; + enable_fingerprint_unification?: boolean; + enable_metadata_passthrough?: boolean; + enable_cch_signing?: boolean; // Payment configuration - payment_enabled?: boolean - payment_min_amount?: number - payment_max_amount?: number - payment_daily_limit?: number - payment_order_timeout_minutes?: number - payment_max_pending_orders?: number - payment_enabled_types?: string[] - payment_balance_disabled?: boolean - payment_balance_recharge_multiplier?: number - payment_recharge_fee_rate?: number - payment_load_balance_strategy?: string - payment_product_name_prefix?: string - payment_product_name_suffix?: string - payment_help_image_url?: string - payment_help_text?: string - payment_cancel_rate_limit_enabled?: boolean - payment_cancel_rate_limit_max?: number - payment_cancel_rate_limit_window?: number - payment_cancel_rate_limit_unit?: string - payment_cancel_rate_limit_window_mode?: string + payment_enabled?: boolean; + payment_min_amount?: number; + payment_max_amount?: number; + payment_daily_limit?: number; + payment_order_timeout_minutes?: number; + payment_max_pending_orders?: number; + payment_enabled_types?: string[]; + payment_balance_disabled?: boolean; + payment_balance_recharge_multiplier?: number; + payment_recharge_fee_rate?: number; + payment_load_balance_strategy?: string; + payment_product_name_prefix?: string; + payment_product_name_suffix?: string; + payment_help_image_url?: string; + payment_help_text?: string; + payment_cancel_rate_limit_enabled?: boolean; + payment_cancel_rate_limit_max?: number; + payment_cancel_rate_limit_window?: number; + payment_cancel_rate_limit_unit?: string; + payment_cancel_rate_limit_window_mode?: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled?: boolean - balance_low_notify_threshold?: number - balance_low_notify_recharge_url?: string - account_quota_notify_enabled?: boolean - account_quota_notify_emails?: NotifyEmailEntry[] + balance_low_notify_enabled?: boolean; + balance_low_notify_threshold?: number; + balance_low_notify_recharge_url?: string; + account_quota_notify_enabled?: boolean; + account_quota_notify_emails?: NotifyEmailEntry[]; } /** @@ -258,8 +625,8 @@ export interface UpdateSettingsRequest { * @returns System settings */ export async function getSettings(): Promise { - const { data } = await apiClient.get('/admin/settings') - return data + const { data } = await apiClient.get("/admin/settings"); + return data; } /** @@ -267,20 +634,25 @@ export async function getSettings(): Promise { * @param settings - Partial settings to update * @returns Updated settings */ -export async function updateSettings(settings: UpdateSettingsRequest): Promise { - const { data } = await apiClient.put('/admin/settings', settings) - return data +export async function updateSettings( + settings: UpdateSettingsRequest, +): Promise { + const { data } = await apiClient.put( + "/admin/settings", + settings, + ); + return data; } /** * Test SMTP connection request */ export interface TestSmtpRequest { - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_use_tls: boolean; } /** @@ -288,23 +660,28 @@ export interface TestSmtpRequest { * @param config - SMTP configuration to test * @returns Test result message */ -export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> { - const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config) - return data +export async function testSmtpConnection( + config: TestSmtpRequest, +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>( + "/admin/settings/test-smtp", + config, + ); + return data; } /** * Send test email request */ export interface SendTestEmailRequest { - email: string - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + email: string; + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; } /** @@ -312,20 +689,22 @@ export interface SendTestEmailRequest { * @param request - Email address and SMTP config * @returns Test result message */ -export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> { +export async function sendTestEmail( + request: SendTestEmailRequest, +): Promise<{ message: string }> { const { data } = await apiClient.post<{ message: string }>( - '/admin/settings/send-test-email', - request - ) - return data + "/admin/settings/send-test-email", + request, + ); + return data; } /** * Admin API Key status response */ export interface AdminApiKeyStatus { - exists: boolean - masked_key: string + exists: boolean; + masked_key: string; } /** @@ -333,8 +712,10 @@ export interface AdminApiKeyStatus { * @returns Status indicating if key exists and masked version */ export async function getAdminApiKey(): Promise { - const { data } = await apiClient.get('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.get( + "/admin/settings/admin-api-key", + ); + return data; } /** @@ -342,8 +723,10 @@ export async function getAdminApiKey(): Promise { * @returns The new full API key (only shown once) */ export async function regenerateAdminApiKey(): Promise<{ key: string }> { - const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate') - return data + const { data } = await apiClient.post<{ key: string }>( + "/admin/settings/admin-api-key/regenerate", + ); + return data; } /** @@ -351,8 +734,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> { * @returns Success message */ export async function deleteAdminApiKey(): Promise<{ message: string }> { - const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.delete<{ message: string }>( + "/admin/settings/admin-api-key", + ); + return data; } // ==================== Overload Cooldown Settings ==================== @@ -361,23 +746,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { * Overload cooldown settings interface (529 handling) */ export interface OverloadCooldownSettings { - enabled: boolean - cooldown_minutes: number + enabled: boolean; + cooldown_minutes: number; } export async function getOverloadCooldownSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/overload-cooldown') - return data + const { data } = await apiClient.get( + "/admin/settings/overload-cooldown", + ); + return data; } export async function updateOverloadCooldownSettings( - settings: OverloadCooldownSettings + settings: OverloadCooldownSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/overload-cooldown', - settings - ) - return data + "/admin/settings/overload-cooldown", + settings, + ); + return data; } // ==================== Stream Timeout Settings ==================== @@ -386,11 +773,11 @@ export async function updateOverloadCooldownSettings( * Stream timeout settings interface */ export interface StreamTimeoutSettings { - enabled: boolean - action: 'temp_unsched' | 'error' | 'none' - temp_unsched_minutes: number - threshold_count: number - threshold_window_minutes: number + enabled: boolean; + action: "temp_unsched" | "error" | "none"; + temp_unsched_minutes: number; + threshold_count: number; + threshold_window_minutes: number; } /** @@ -398,8 +785,10 @@ export interface StreamTimeoutSettings { * @returns Stream timeout settings */ export async function getStreamTimeoutSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/stream-timeout') - return data + const { data } = await apiClient.get( + "/admin/settings/stream-timeout", + ); + return data; } /** @@ -408,13 +797,13 @@ export async function getStreamTimeoutSettings(): Promise * @returns Updated settings */ export async function updateStreamTimeoutSettings( - settings: StreamTimeoutSettings + settings: StreamTimeoutSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/stream-timeout', - settings - ) - return data + "/admin/settings/stream-timeout", + settings, + ); + return data; } // ==================== Rectifier Settings ==================== @@ -423,11 +812,11 @@ export async function updateStreamTimeoutSettings( * Rectifier settings interface */ export interface RectifierSettings { - enabled: boolean - thinking_signature_enabled: boolean - thinking_budget_enabled: boolean - apikey_signature_enabled: boolean - apikey_signature_patterns: string[] + enabled: boolean; + thinking_signature_enabled: boolean; + thinking_budget_enabled: boolean; + apikey_signature_enabled: boolean; + apikey_signature_patterns: string[]; } /** @@ -435,8 +824,10 @@ export interface RectifierSettings { * @returns Rectifier settings */ export async function getRectifierSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/rectifier') - return data + const { data } = await apiClient.get( + "/admin/settings/rectifier", + ); + return data; } /** @@ -445,13 +836,13 @@ export async function getRectifierSettings(): Promise { * @returns Updated settings */ export async function updateRectifierSettings( - settings: RectifierSettings + settings: RectifierSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/rectifier', - settings - ) - return data + "/admin/settings/rectifier", + settings, + ); + return data; } // ==================== Beta Policy Settings ==================== @@ -460,20 +851,20 @@ export async function updateRectifierSettings( * Beta policy rule interface */ export interface BetaPolicyRule { - beta_token: string - action: 'pass' | 'filter' | 'block' - scope: 'all' | 'oauth' | 'apikey' | 'bedrock' - error_message?: string - model_whitelist?: string[] - fallback_action?: 'pass' | 'filter' | 'block' - fallback_error_message?: string + beta_token: string; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; } /** * Beta policy settings interface */ export interface BetaPolicySettings { - rules: BetaPolicyRule[] + rules: BetaPolicyRule[]; } /** @@ -481,8 +872,10 @@ export interface BetaPolicySettings { * @returns Beta policy settings */ export async function getBetaPolicySettings(): Promise { - const { data } = await apiClient.get('/admin/settings/beta-policy') - return data + const { data } = await apiClient.get( + "/admin/settings/beta-policy", + ); + return data; } /** @@ -491,70 +884,73 @@ export async function getBetaPolicySettings(): Promise { * @returns Updated settings */ export async function updateBetaPolicySettings( - settings: BetaPolicySettings + settings: BetaPolicySettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/beta-policy', - settings - ) - return data + "/admin/settings/beta-policy", + settings, + ); + return data; } // --- Web Search Emulation Config --- export interface WebSearchProviderConfig { - type: 'brave' | 'tavily' - api_key: string - api_key_configured: boolean - quota_limit: number | null - subscribed_at: number | null - quota_used?: number - proxy_id: number | null - expires_at: number | null + type: "brave" | "tavily"; + api_key: string; + api_key_configured: boolean; + quota_limit: number | null; + subscribed_at: number | null; + quota_used?: number; + proxy_id: number | null; + expires_at: number | null; } export interface WebSearchEmulationConfig { - enabled: boolean - providers: WebSearchProviderConfig[] + enabled: boolean; + providers: WebSearchProviderConfig[]; } export interface WebSearchTestResult { - provider: string - results: { url: string; title: string; snippet: string; page_age?: string }[] - query: string + provider: string; + results: { url: string; title: string; snippet: string; page_age?: string }[]; + query: string; } export async function getWebSearchEmulationConfig(): Promise { const { data } = await apiClient.get( - '/admin/settings/web-search-emulation' - ) - return data + "/admin/settings/web-search-emulation", + ); + return data; } export async function updateWebSearchEmulationConfig( - config: WebSearchEmulationConfig + config: WebSearchEmulationConfig, ): Promise { const { data } = await apiClient.put( - '/admin/settings/web-search-emulation', - config - ) - return data + "/admin/settings/web-search-emulation", + config, + ); + return data; } export async function testWebSearchEmulation( - query: string + query: string, ): Promise { const { data } = await apiClient.post( - '/admin/settings/web-search-emulation/test', - { query } - ) - return data + "/admin/settings/web-search-emulation/test", + { query }, + ); + return data; } -export async function resetWebSearchUsage( - payload: { provider_type: string } -): Promise { - await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +export async function resetWebSearchUsage(payload: { + provider_type: string; +}): Promise { + await apiClient.post( + "/admin/settings/web-search-emulation/reset-usage", + payload, + ); } export const settingsAPI = { @@ -576,7 +972,7 @@ export const settingsAPI = { getWebSearchEmulationConfig, updateWebSearchEmulationConfig, testWebSearchEmulation, - resetWebSearchUsage -} + resetWebSearchUsage, +}; -export default settingsAPI +export default settingsAPI; diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 39cb1dfa69217d7492e592d955cc5d7f2eb2aa63..1bb3d54c75ac7543b66218c29457f9b9d6da217e 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -6,6 +6,30 @@ import { apiClient } from '../client' import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types' +export interface AdminBindAuthIdentityChannelRequest { + channel: string + channel_app_id?: string + channel_subject: string + metadata?: Record +} + +export interface AdminBindAuthIdentityRequest { + provider_type: string + provider_key: string + provider_subject: string + issuer?: string + metadata?: Record + channel?: AdminBindAuthIdentityChannelRequest +} + +export interface AdminBoundAuthIdentity { + identity_id: number + provider_type: string + provider_key: string + provider_subject: string + channel_id?: number | null +} + /** * List all users with pagination * @param page - Page number (default: 1) @@ -248,6 +272,17 @@ export async function replaceGroup( return data } +export async function bindUserAuthIdentity( + userId: number, + input: AdminBindAuthIdentityRequest +): Promise { + const { data } = await apiClient.post( + `/admin/users/${userId}/auth-identities`, + input + ) + return data +} + export const usersAPI = { list, getById, @@ -260,7 +295,8 @@ export const usersAPI = { getUserApiKeys, getUserUsageStats, getUserBalanceHistory, - replaceGroup + replaceGroup, + bindUserAuthIdentity } export default usersAPI diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 837c4f4cf7068fe1ce394f49ad83131d346cff01..9244489ca701c5578403d4f90a36e45d98518302 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,6 +186,127 @@ export interface RefreshTokenResponse { token_type: string } +export interface OAuthTokenResponse { + access_token: string + refresh_token?: string + expires_in?: number + token_type?: string +} + +export interface PendingOAuthBindLoginResponse extends Partial { + redirect?: string + error?: string + requires_2fa?: boolean + temp_token?: string + user_email_masked?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +} + +export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse + +export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {} + +export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse { + auth_result?: string + provider?: string + redirect?: string +} + +export type OAuthCompletionKind = 'login' | 'bind' + +export interface OAuthAdoptionDecision { + adoptDisplayName?: boolean + adoptAvatar?: boolean +} + +function serializeOAuthAdoptionDecision( + decision?: OAuthAdoptionDecision +): Record { + const payload: Record = {} + + if (typeof decision?.adoptDisplayName === 'boolean') { + payload.adopt_display_name = decision.adoptDisplayName + } + if (typeof decision?.adoptAvatar === 'boolean') { + payload.adopt_avatar = decision.adoptAvatar + } + + return payload +} + +export function isOAuthLoginCompletion( + completion: Partial +): completion is OAuthTokenResponse { + return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0 +} + +export function getOAuthCompletionKind( + completion: Partial +): OAuthCompletionKind { + return isOAuthLoginCompletion(completion) ? 'login' : 'bind' +} + +export function getPendingOAuthBindLoginKind( + completion: PendingOAuthBindLoginResponse +): OAuthCompletionKind { + return getOAuthCompletionKind(completion) +} + +export function isPendingOAuthCreateAccountRequired( + completion: Pick +): boolean { + return completion.error === 'invitation_required' +} + +export function hasPendingOAuthSuggestedProfile( + completion: Pick< + PendingOAuthBindLoginResponse, + 'suggested_display_name' | 'suggested_avatar_url' + > +): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +export function persistOAuthTokenContext(tokens: Partial): void { + if (tokens.refresh_token) { + setRefreshToken(tokens.refresh_token) + } + if (tokens.expires_in) { + setTokenExpiresAt(tokens.expires_in) + } +} + +export function prepareOAuthBindAccessTokenCookie(): void { + if (typeof document === 'undefined' || typeof window === 'undefined') { + return + } + + const token = getAuthToken() + if (!token) { + return + } + + const secure = window.location.protocol === 'https:' ? '; Secure' : '' + const path = resolveOAuthBindCookiePath() + document.cookie = + `oauth_bind_access_token=${encodeURIComponent(token)}; Path=${path}/auth/oauth; Max-Age=600; SameSite=Lax${secure}` +} + +function resolveOAuthBindCookiePath(): string { + const apiBase = ((import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1').replace(/\/$/, '') + + try { + return new URL(apiBase, window.location.origin).pathname.replace(/\/$/, '') || '/api/v1' + } catch { + if (apiBase.startsWith('/')) { + return apiBase + } + return '/api/v1' + } +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -234,6 +355,116 @@ export async function getPublicSettings(): Promise { return data } +export type WeChatOAuthMode = 'open' | 'mp' +export type WeChatOAuthUnavailableReason = + | 'not_configured' + | 'capability_unknown' + | 'external_browser_required' + | 'wechat_browser_required' + | 'native_app_required' + +export interface ResolvedWeChatOAuthStart { + mode: WeChatOAuthMode | null + openEnabled: boolean + mpEnabled: boolean + mobileEnabled: boolean + isWeChatBrowser: boolean + unavailableReason: WeChatOAuthUnavailableReason | null +} + +export type WeChatOAuthPublicSettings = { + wechat_oauth_enabled?: boolean + wechat_oauth_open_enabled?: boolean + wechat_oauth_mp_enabled?: boolean + wechat_oauth_mobile_enabled?: boolean +} + +export function isWeChatWebOAuthEnabled( + settings: WeChatOAuthPublicSettings | null | undefined, +): boolean { + const legacyEnabled = settings?.wechat_oauth_enabled ?? false + const hasExplicitCapabilities = + typeof settings?.wechat_oauth_open_enabled === 'boolean' || + typeof settings?.wechat_oauth_mp_enabled === 'boolean' + + if (!hasExplicitCapabilities) { + return legacyEnabled + } + + return settings?.wechat_oauth_open_enabled === true || settings?.wechat_oauth_mp_enabled === true +} + +export function hasExplicitWeChatOAuthCapabilities( + settings: WeChatOAuthPublicSettings | null | undefined, +): settings is WeChatOAuthPublicSettings & { + wechat_oauth_open_enabled: boolean + wechat_oauth_mp_enabled: boolean +} { + return typeof settings?.wechat_oauth_open_enabled === 'boolean' + && typeof settings?.wechat_oauth_mp_enabled === 'boolean' +} + +export function resolveWeChatOAuthStart( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? '').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + const legacyEnabled = settings?.wechat_oauth_enabled ?? false + const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean' + ? settings.wechat_oauth_open_enabled + : legacyEnabled + const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean' + ? settings.wechat_oauth_mp_enabled + : legacyEnabled + const mobileEnabled = typeof settings?.wechat_oauth_mobile_enabled === 'boolean' + ? settings.wechat_oauth_mobile_enabled + : false + + if (isWeChatBrowser) { + if (mpEnabled) { + return { mode: 'mp', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null } + } + if (openEnabled) { + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } + } + + if (openEnabled) { + return { mode: 'open', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null } + } + if (mpEnabled) { + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } +} + +export function resolveWeChatOAuthStartStrict( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string, +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? '').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + + if (!hasExplicitWeChatOAuthCapabilities(settings)) { + return { + mode: null, + openEnabled: false, + mpEnabled: false, + mobileEnabled: false, + isWeChatBrowser, + unavailableReason: 'capability_unknown', + } + } + + return resolveWeChatOAuthStart(settings, normalizedUserAgent) +} + /** * Send verification code to email * @param request - Email and optional Turnstile token @@ -246,6 +477,16 @@ export async function sendVerifyCode( return data } +export async function sendPendingOAuthVerifyCode( + request: SendVerifyCodeRequest +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/send-verify-code', + request + ) + return data +} + /** * Validate promo code response */ @@ -337,48 +578,87 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/linuxdo/complete-registration', { - pending_oauth_token: pendingOAuthToken, - invitation_code: invitationCode - }) - return data + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingLinuxDoOAuthAccount(invitationCode, decision) } /** * Complete OIDC OAuth registration by supplying an invitation code - * @param pendingOAuthToken - Short-lived JWT from the OAuth callback * @param invitationCode - Invitation code entered by the user * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - pendingOAuthToken: string, - invitationCode: string -): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/oidc/complete-registration', { - pending_oauth_token: pendingOAuthToken, - invitation_code: invitationCode - }) + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOIDCOAuthAccount(invitationCode, decision) +} + +export async function completeWeChatOAuthRegistration( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingWeChatOAuthAccount(invitationCode, decision) +} + +async function createPendingOAuthAccount( + provider: 'linuxdo' | 'oidc' | 'wechat', + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + `/auth/oauth/${provider}/complete-registration`, + { + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) + } + ) + return data +} + +export async function createPendingLinuxDoOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('linuxdo', invitationCode, decision) +} + +export async function createPendingOIDCOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('oidc', invitationCode, decision) +} + +export async function createPendingWeChatOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('wechat', invitationCode, decision) +} + +export async function completePendingOAuthBindLogin( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) return data } +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + return completePendingOAuthBindLogin(decision) +} + export const authAPI = { login, login2FA, @@ -396,14 +676,24 @@ export const authAPI = { clearAuthToken, getPublicSettings, sendVerifyCode, + sendPendingOAuthVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, resetPassword, refreshToken, revokeAllSessions, + getPendingOAuthBindLoginKind, + isPendingOAuthCreateAccountRequired, + hasPendingOAuthSuggestedProfile, + completePendingOAuthBindLogin, + createPendingLinuxDoOAuthAccount, + createPendingOIDCOAuthAccount, + createPendingWeChatOAuthAccount, + exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, - completeOIDCOAuthRegistration + completeOIDCOAuthRegistration, + completeWeChatOAuthRegistration } export default authAPI diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 5cedb107ec77e14f92c86b2f605e2353fa9862bc..e866e184f55bc6fcb410c7bd85d2321454c9af2b 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -67,9 +67,9 @@ export const paymentAPI = { return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo }) }, - /** Verify order payment status without auth (public endpoint for result page) */ - verifyOrderPublic(outTradeNo: string) { - return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo }) + /** Resolve an order from a signed resume token without auth */ + resolveOrderPublicByResumeToken(resumeToken: string) { + return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken }) }, /** Request a refund for a completed order */ diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index cd6482708f3faf3d269dba0dfbc3f7899edf2097..32ef07e0f3920f1291ce5537031de290b876cb07 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -4,7 +4,12 @@ */ import { apiClient } from './client' -import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types' +import { + resolveWeChatOAuthStartStrict, + prepareOAuthBindAccessTokenCookie, + type WeChatOAuthPublicSettings, +} from './auth' +import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types' /** * Get current user profile @@ -22,6 +27,7 @@ export async function getProfile(): Promise { */ export async function updateProfile(profile: { username?: string + avatar_url?: string | null balance_notify_enabled?: boolean balance_notify_threshold?: number | null balance_notify_extra_emails?: NotifyEmailEntry[] @@ -83,6 +89,85 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi return data } +export async function sendEmailBindingCode(email: string): Promise { + await apiClient.post('/user/account-bindings/email/send-code', { email }) +} + +export async function bindEmailIdentity(payload: { + email: string + verify_code: string + password: string +}): Promise { + const { data } = await apiClient.post('/user/account-bindings/email', payload) + return data +} + +export async function unbindAuthIdentity(provider: BindableOAuthProvider): Promise { + const { data } = await apiClient.delete(`/user/account-bindings/${provider}`) + return data +} + +export type BindableOAuthProvider = Exclude + +interface BuildOAuthBindingStartURLOptions { + redirectTo?: string + wechatOAuthSettings?: WeChatOAuthPublicSettings | null +} + +export function resolveWeChatOAuthMode(): 'open' | 'mp' { + if (typeof navigator === 'undefined') { + return 'open' + } + return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open' +} + +function resolveWeChatOAuthBindingMode( + settings?: WeChatOAuthPublicSettings | null +): 'open' | 'mp' | null { + if (settings) { + return resolveWeChatOAuthStartStrict(settings).mode + } + return resolveWeChatOAuthMode() +} + +export function buildOAuthBindingStartURL( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): string | null { + const redirectTo = options.redirectTo?.trim() || '/profile' + const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1' + const normalized = apiBase.replace(/\/$/, '') + const params = new URLSearchParams({ + redirect: redirectTo, + intent: 'bind_current_user' + }) + + if (provider === 'wechat') { + const mode = resolveWeChatOAuthBindingMode(options.wechatOAuthSettings) + if (!mode) { + return null + } + params.set('mode', mode) + } + + return `${normalized}/auth/oauth/${provider}/start?${params.toString()}` +} + +export function startOAuthBinding( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): void { + if (typeof window === 'undefined') { + return + } + const startURL = buildOAuthBindingStartURL(provider, options) + if (!startURL) { + return + } + prepareOAuthBindAccessTokenCookie() + window.location.href = startURL +} + export const userAPI = { getProfile, updateProfile, @@ -90,7 +175,12 @@ export const userAPI = { sendNotifyEmailCode, verifyNotifyEmail, removeNotifyEmail, - toggleNotifyEmail + toggleNotifyEmail, + sendEmailBindingCode, + bindEmailIdentity, + unbindAuthIdentity, + buildOAuthBindingStartURL, + startOAuthBinding } export default userAPI diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue new file mode 100644 index 0000000000000000000000000000000000000000..a566e264e143e0da26cbea10737128d81f72b0e2 --- /dev/null +++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue @@ -0,0 +1,288 @@ + + + + + diff --git a/frontend/src/components/auth/TotpLoginModal.vue b/frontend/src/components/auth/TotpLoginModal.vue index 03fa718de8d6c2fb86788d02c4d4ca42d3e23cab..0ae2f48293776b8c50c287572ab978b3333101b0 100644 --- a/frontend/src/components/auth/TotpLoginModal.vue +++ b/frontend/src/components/auth/TotpLoginModal.vue @@ -47,11 +47,6 @@ - -
- {{ error }} -
- ' + } + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ') + + expect(wrapper.get('[data-testid="linuxdo-create-account-send-code"]').attributes('disabled')).toBeDefined() + + await wrapper.get('[data-testid="turnstile-verify"]').trigger('click') + await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ + email: 'user@example.com', + turnstile_token: 'turnstile-token' + }) + }) +}) diff --git a/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..06fbe397579cbb7e9439e739219f03d3a2056b1c --- /dev/null +++ b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts @@ -0,0 +1,41 @@ +import { mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import TotpLoginModal from '@/components/auth/TotpLoginModal.vue' + +const { showErrorMock } = vi.hoisted(() => ({ + showErrorMock: vi.fn(), +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +describe('TotpLoginModal', () => { + beforeEach(() => { + showErrorMock.mockReset() + }) + + it('sends verification errors to toast and does not render inline red text', async () => { + const wrapper = mount(TotpLoginModal, { + props: { + tempToken: 'temp-token', + userEmailMasked: 'u***@example.com', + }, + }) + + ;(wrapper.vm as unknown as { setError: (message: string) => void }).setError('Invalid code') + await wrapper.vm.$nextTick() + + expect(showErrorMock).toHaveBeenCalledWith('Invalid code') + expect(wrapper.text()).not.toContain('Invalid code') + expect(wrapper.find('.bg-red-50').exists()).toBe(false) + }) +}) diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..2f269e0bc53b2165dfe0f86350bef9397c465352 --- /dev/null +++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts @@ -0,0 +1,238 @@ +import { mount } from '@vue/test-utils' +import { createPinia, setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue' +import { useAppStore } from '@/stores' +import type { PublicSettings } from '@/types' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/login' } as { href: string }, +})) + +let pinia: ReturnType + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + locale: { value: 'en' }, + t: (key: string, params?: Record) => { + if (key === 'auth.wechatProviderName') { + return 'Mock WeChat' + } + if (key === 'auth.oidc.signIn') { + return `Continue with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oauthFlow.wechatSystemBrowserOnly') { + return 'MOCK-SYSTEM-BROWSER-ONLY' + } + if (key === 'auth.oauthFlow.wechatBrowserOnly') { + return 'MOCK-WECHAT-BROWSER-ONLY' + } + if (key === 'auth.oauthFlow.wechatNotConfigured') { + return 'MOCK-NOT-CONFIGURED' + } + if (key === 'auth.oauthOrContinue') { + return 'or continue' + } + return key + }, + }), + } +}) + +type WeChatPublicSettings = PublicSettings & { + wechat_oauth_open_enabled?: boolean + wechat_oauth_mp_enabled?: boolean +} + +function buildPublicSettings(overrides: Partial = {}): WeChatPublicSettings { + return { + registration_enabled: true, + email_verify_enabled: false, + force_email_on_third_party_signup: false, + registration_email_suffix_whitelist: [], + promo_code_enabled: true, + password_reset_enabled: false, + invitation_code_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + site_logo: '', + site_subtitle: '', + api_base_url: '/api/v1', + contact_info: '', + doc_url: '', + home_content: '', + hide_ccs_import_button: false, + payment_enabled: false, + table_default_page_size: 20, + table_page_size_options: [10, 20, 50, 100], + custom_menu_items: [], + custom_endpoints: [], + linuxdo_oauth_enabled: false, + wechat_oauth_enabled: true, + oidc_oauth_enabled: false, + oidc_oauth_provider_name: 'OIDC', + backend_mode_enabled: false, + version: 'test', + balance_low_notify_enabled: false, + account_quota_notify_enabled: false, + balance_low_notify_threshold: 0, + ...overrides, + } +} + +function seedPublicSettings(overrides: Partial = {}): void { + const appStore = useAppStore() + const settings = buildPublicSettings(overrides) + appStore.cachedPublicSettings = settings + appStore.publicSettingsLoaded = true +} + +describe('WechatOAuthSection', () => { + beforeEach(() => { + pinia = createPinia() + setActivePinia(pinia) + routeState.query = { redirect: '/billing?plan=pro' } + locationState.current = { href: 'http://localhost/login' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('starts the open WeChat OAuth flow with the current redirect target when open mode is configured', async () => { + seedPublicSettings({ + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + }) + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + expect(wrapper.text()).toContain('Mock WeChat') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('uses mp mode inside the WeChat browser when mp mode is configured', async () => { + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0 MicroMessenger', + }) + seedPublicSettings({ + wechat_oauth_open_enabled: false, + wechat_oauth_mp_enabled: true, + }) + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('disables the button outside the WeChat browser when only mp mode is configured', async () => { + seedPublicSettings({ + wechat_oauth_open_enabled: false, + wechat_oauth_mp_enabled: true, + }) + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + expect(wrapper.get('button').attributes('disabled')).toBeDefined() + expect(wrapper.text()).toContain('MOCK-WECHAT-BROWSER-ONLY') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toBe('http://localhost/login') + }) + + it('disables the button inside the WeChat browser when only open mode is configured', async () => { + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0 MicroMessenger', + }) + seedPublicSettings({ + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + }) + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + expect(wrapper.get('button').attributes('disabled')).toBeDefined() + expect(wrapper.text()).toContain('MOCK-SYSTEM-BROWSER-ONLY') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toBe('http://localhost/login') + }) + + it('uses the legacy overall enabled flag when per-mode settings are not present', async () => { + seedPublicSettings({ + wechat_oauth_enabled: true, + }) + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('shows the localized not-configured hint when WeChat OAuth is unavailable', async () => { + seedPublicSettings({ + wechat_oauth_enabled: false, + wechat_oauth_open_enabled: false, + wechat_oauth_mp_enabled: false, + }) + + const wrapper = mount(WechatOAuthSection, { + global: { + plugins: [pinia], + }, + }) + + expect(wrapper.text()).toContain('MOCK-NOT-CONFIGURED') + }) +}) diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue index fbcab521581c1a47a0931c3924bc36cdaed9f880..306f1429d8dea80fbe914396e1a9e00e0dc994f6 100644 --- a/frontend/src/components/layout/AppHeader.vue +++ b/frontend/src/components/layout/AppHeader.vue @@ -74,10 +74,14 @@ class="flex items-center gap-2 rounded-xl p-1.5 transition-colors hover:bg-gray-100 dark:hover:bg-dark-800" aria-label="User Menu" > -
- {{ userInitials }} +
+ + {{ userInitials }}
-

- {{ errors.password }} -

- + -

- {{ errors.turnstile }} -

- - -
-
-
- -
-

- {{ errorMessage }} -

-
-
-
-
-
-

- {{ t('auth.oidc.invitationRequired', { providerName }) }} -

-
- +
+
+
+
+

+ {{ t('auth.oauthFlow.profileDetailsTitle', { providerName }) }} +

+

+ {{ t('auth.oauthFlow.profileDetailsDescription', { providerName }) }} +

+
+ + + + +
- -

- {{ invitationError }} + + - -

-
-
- + + + + + + + + +
@@ -73,15 +243,26 @@ + + diff --git a/frontend/src/views/auth/WechatPaymentCallbackView.vue b/frontend/src/views/auth/WechatPaymentCallbackView.vue new file mode 100644 index 0000000000000000000000000000000000000000..53599ec3c730967df94d83254309b83cf215e089 --- /dev/null +++ b/frontend/src/views/auth/WechatPaymentCallbackView.vue @@ -0,0 +1,126 @@ + + + diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..9f67a994286cc1d85072b5a022054c8c374fafdf --- /dev/null +++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts @@ -0,0 +1,453 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import EmailVerifyView from '@/views/auth/EmailVerifyView.vue' + +const { + pushMock, + showSuccessMock, + showErrorMock, + registerMock, + setTokenMock, + setPendingAuthSessionMock, + clearPendingAuthSessionMock, + getPublicSettingsMock, + sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, + persistOAuthTokenContextMock, + apiClientPostMock, + authStoreState, +} = vi.hoisted(() => ({ + pushMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + registerMock: vi.fn(), + setTokenMock: vi.fn(), + setPendingAuthSessionMock: vi.fn(), + clearPendingAuthSessionMock: vi.fn(), + getPublicSettingsMock: vi.fn(), + sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), + persistOAuthTokenContextMock: vi.fn(), + apiClientPostMock: vi.fn(), + authStoreState: { + pendingAuthSession: null as null | { + token: string + token_field: 'pending_auth_token' | 'pending_oauth_token' + provider: string + redirect?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string + } + }, +})) + +vi.mock('vue-router', () => ({ + useRouter: () => ({ + push: pushMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.accountCreatedSuccess') { + return `Account created for ${params?.siteName ?? 'Sub2API'}` + } + return key + }, + locale: { value: 'en' }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + pendingAuthSession: authStoreState.pendingAuthSession, + register: (...args: any[]) => registerMock(...args), + setToken: (...args: any[]) => setTokenMock(...args), + setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args), + clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args), + }), + useAppStore: () => ({ + showSuccess: (...args: any[]) => showSuccessMock(...args), + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), + persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args), + } +}) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPostMock(...args), + }, +})) + +describe('EmailVerifyView', () => { + beforeEach(() => { + pushMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + registerMock.mockReset() + setTokenMock.mockReset() + setPendingAuthSessionMock.mockReset() + clearPendingAuthSessionMock.mockReset() + getPublicSettingsMock.mockReset() + sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() + persistOAuthTokenContextMock.mockReset() + apiClientPostMock.mockReset() + authStoreState.pendingAuthSession = null + sessionStorage.clear() + + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: [], + }) + sendVerifyCodeMock.mockResolvedValue({ countdown: 60 }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 }) + setTokenMock.mockResolvedValue({}) + }) + + it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-1', + token_field: 'pending_auth_token', + provider: 'wechat', + redirect: '/profile', + } + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-1', + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + }) + + it('skips the registration email suffix whitelist for pending oauth verification', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-2', + token_field: 'pending_auth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-2', + }) + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_oauth_token: undefined, + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ + auth_result: 'pending_session', + provider: 'oidc', + redirect: '/profile/security', + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-1', + token_field: 'pending_auth_token', + provider: 'wechat', + redirect: '/profile', + } + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'oauth-access-token', + refresh_token: 'oauth-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('123456') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'fresh@example.com', + password: 'secret-123', + verify_code: '123456', + }) + expect(persistOAuthTokenContextMock).toHaveBeenCalledWith({ + access_token: 'oauth-access-token', + refresh_token: 'oauth-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }) + expect(setTokenMock).toHaveBeenCalledWith('oauth-access-token') + expect(clearPendingAuthSessionMock).toHaveBeenCalled() + expect(pushMock).toHaveBeenCalledWith('/profile') + expect(registerMock).not.toHaveBeenCalled() + }) + + it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + apiClientPostMock.mockResolvedValue({ + data: { + auth_result: 'pending_session', + provider: 'oidc', + step: 'bind_login_required', + redirect: '/profile/security', + email: 'fresh@example.com', + }, + }) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('123456') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'fresh@example.com', + password: 'secret-123', + verify_code: '123456', + }) + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(setTokenMock).not.toHaveBeenCalled() + expect(persistOAuthTokenContextMock).not.toHaveBeenCalled() + expect(clearPendingAuthSessionMock).not.toHaveBeenCalled() + expect(showSuccessMock).not.toHaveBeenCalled() + }) + + it('keeps the normal email registration flow unchanged', async () => { + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'normal@example.com', + password: 'secret-456', + promo_code: 'PROMO', + invitation_code: 'INVITE', + }) + ) + registerMock.mockResolvedValue({}) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('654321') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(registerMock).toHaveBeenCalledWith({ + email: 'normal@example.com', + password: 'secret-456', + verify_code: '654321', + turnstile_token: undefined, + promo_code: 'PROMO', + invitation_code: 'INVITE', + }) + expect(apiClientPostMock).not.toHaveBeenCalled() + expect(pushMock).toHaveBeenCalledWith('/dashboard') + }) +}) diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..29aef61340e393723b427d212f508ed9e660aba5 --- /dev/null +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -0,0 +1,668 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import LinuxDoCallbackView from '../LinuxDoCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const setPendingAuthSession = vi.fn() +const clearPendingAuthSession = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeLinuxDoOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() +const login2FA = vi.fn() +const apiClientPost = vi.fn() +const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? ''}`.trim() + } + return key + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken, + setPendingAuthSession, + clearPendingAuthSession + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPost(...args) + } +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args), + login2FA: (...args: any[]) => login2FA(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) + } +}) + +describe('LinuxDoCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + setPendingAuthSession.mockReset() + clearPendingAuthSession.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeLinuxDoOAuthRegistration.mockReset() + getPublicSettings.mockReset() + login2FA.mockReset() + apiClientPost.mockReset() + sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() + getPublicSettings.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '' + }) + window.location.hash = '' + localStorage.clear() + }) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + window.location.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess') + expect(replace).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + apiClientPost.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + adopt_display_name: true, + adopt_avatar: true, + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code' + }) + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(replace).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('treats a completion without token as bind success and returns to profile', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile/security' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/profile/security', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setPendingAuthSession).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'linuxdo', + redirect: '/welcome' + }) + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[0].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false + }) + expect(setToken).toHaveBeenCalledWith('new-access-token') + expect(replace).toHaveBeenCalledWith('/welcome') + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showError).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + sendPendingOAuthVerifyCode.mockResolvedValue({ + message: 'sent', + countdown: 60 + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ + email: 'new@example.com' + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="linuxdo-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="linuxdo-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="linuxdo-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true + }) + expect(setToken).toHaveBeenCalledWith('bind-access-token') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com' + } + }) + login2FA.mockResolvedValue({ + access_token: '2fa-access-token' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="linuxdo-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="linuxdo-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FA).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="linuxdo-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="linuxdo-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456' + }) + expect(setToken).toHaveBeenCalledWith('2fa-access-token') + expect(replace).toHaveBeenCalledWith('/profile') + }) +}) diff --git a/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..1669f763f2cea57b5f6d8128974edca691161433 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts @@ -0,0 +1,68 @@ +import { mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import OAuthCallbackView from '@/views/auth/OAuthCallbackView.vue' + +const { routeState, showErrorMock, copyToClipboardMock } = vi.hoisted(() => ({ + routeState: { + query: {} as Record, + }, + showErrorMock: vi.fn(), + copyToClipboardMock: vi.fn(), +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +vi.mock('@/composables/useClipboard', () => ({ + useClipboard: () => ({ + copyToClipboard: (...args: any[]) => copyToClipboardMock(...args), + }), +})) + +describe('OAuthCallbackView', () => { + beforeEach(() => { + routeState.query = {} + showErrorMock.mockReset() + copyToClipboardMock.mockReset() + }) + + it('renders localized callback copy actions', () => { + routeState.query = { + code: 'oauth-code', + state: 'oauth-state', + } + + const wrapper = mount(OAuthCallbackView) + + expect(wrapper.text()).toContain('auth.oauth.callbackTitle') + expect(wrapper.text()).toContain('auth.oauth.callbackHint') + expect(wrapper.text()).toContain('common.copy') + expect(wrapper.find('input[value="oauth-code"]').exists()).toBe(true) + expect(wrapper.find('input[value="oauth-state"]').exists()).toBe(true) + }) + + it('sends callback errors to toast instead of rendering inline red text', () => { + routeState.query = { + error: 'oauth failed', + } + + const wrapper = mount(OAuthCallbackView) + + expect(showErrorMock).toHaveBeenCalledWith('oauth failed') + expect(wrapper.text()).not.toContain('oauth failed') + expect(wrapper.find('.bg-red-50').exists()).toBe(false) + }) +}) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..0167604c0a30f914b5880bf4190256d288b68149 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -0,0 +1,645 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import OidcCallbackView from '../OidcCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const setPendingAuthSession = vi.fn() +const clearPendingAuthSession = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeOIDCOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() +const login2FA = vi.fn() +const apiClientPost = vi.fn() +const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? ''}`.trim() + } + if (!params?.providerName) { + return key + } + return `${key}:${params.providerName}` + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken, + setPendingAuthSession, + clearPendingAuthSession + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPost(...args) + } +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args), + login2FA: (...args: any[]) => login2FA(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) + } +}) + +describe('OidcCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + setPendingAuthSession.mockReset() + clearPendingAuthSession.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeOIDCOAuthRegistration.mockReset() + getPublicSettings.mockReset() + login2FA.mockReset() + apiClientPost.mockReset() + sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID', + turnstile_enabled: false, + turnstile_site_key: '' + }) + window.location.hash = '' + localStorage.clear() + }) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + window.location.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess') + expect(replace).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + apiClientPost.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + adopt_display_name: true, + adopt_avatar: true, + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code' + }) + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(replace).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[0].setValue(false) + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: false, + adoptAvatar: true + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/profile', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setPendingAuthSession).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/welcome' + }) + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID', + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false + }) + expect(setToken).toHaveBeenCalledWith('new-access-token') + expect(replace).toHaveBeenCalledWith('/welcome') + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showError).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + sendPendingOAuthVerifyCode.mockResolvedValue({ + message: 'sent', + countdown: 60 + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ + email: 'new@example.com' + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="oidc-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="oidc-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="oidc-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true + }) + expect(setToken).toHaveBeenCalledWith('bind-access-token') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com' + } + }) + login2FA.mockResolvedValue({ + access_token: '2fa-access-token' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="oidc-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="oidc-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FA).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="oidc-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="oidc-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456' + }) + expect(setToken).toHaveBeenCalledWith('2fa-access-token') + expect(replace).toHaveBeenCalledWith('/profile') + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..cc72107d46f3d67a1d2ded7437b7bdbe11c6acfe --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -0,0 +1,1017 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' + +const { + exchangePendingOAuthCompletionMock, + completeWeChatOAuthRegistrationMock, + login2FAMock, + apiClientPostMock, + sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, + getPublicSettingsMock, + prepareOAuthBindAccessTokenCookieMock, + getAuthTokenMock, + replaceMock, + setTokenMock, + setPendingAuthSessionMock, + clearPendingAuthSessionMock, + showSuccessMock, + showErrorMock, + fetchPublicSettingsMock, + routeState, + locationState, + appStoreState, +} = vi.hoisted(() => ({ + exchangePendingOAuthCompletionMock: vi.fn(), + completeWeChatOAuthRegistrationMock: vi.fn(), + login2FAMock: vi.fn(), + apiClientPostMock: vi.fn(), + sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), + getPublicSettingsMock: vi.fn(), + prepareOAuthBindAccessTokenCookieMock: vi.fn(), + getAuthTokenMock: vi.fn(), + replaceMock: vi.fn(), + setTokenMock: vi.fn(), + setPendingAuthSessionMock: vi.fn(), + clearPendingAuthSessionMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + fetchPublicSettingsMock: vi.fn(), + routeState: { + query: {} as Record, + }, + locationState: { + current: { + href: 'http://localhost/auth/wechat/callback', + hash: '', + search: '', + pathname: '/auth/wechat/callback' + } as { href: string; hash: string; search: string; pathname: string }, + }, + appStoreState: { + cachedPublicSettings: null as null | Record, + publicSettingsLoaded: false, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackTitle') { + return `Signing you in with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackProcessing') { + return `Completing login with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.invitationRequired') { + return `${params?.providerName ?? ''} invitation required`.trim() + } + if (key === 'auth.oidc.completeRegistration') { + return 'Complete registration' + } + if (key === 'auth.oidc.completing') { + return 'Completing' + } + if (key === 'auth.oidc.backToLogin') { + return 'Back to login' + } + if (key === 'auth.invitationCodePlaceholder') { + return 'Invitation code' + } + if (key === 'auth.loginSuccess') { + return 'Login success' + } + if (key === 'auth.loginFailed') { + return 'Login failed' + } + if (key === 'auth.oidc.callbackHint') { + return 'Callback hint' + } + if (key === 'auth.oidc.callbackMissingToken') { + return 'Missing login token' + } + if (key === 'auth.oidc.completeRegistrationFailed') { + return 'Complete registration failed' + } + return key + }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken: setTokenMock, + setPendingAuthSession: setPendingAuthSessionMock, + clearPendingAuthSession: clearPendingAuthSessionMock, + }), + useAppStore: () => ({ + ...appStoreState, + showSuccess: showSuccessMock, + showError: showErrorMock, + fetchPublicSettings: fetchPublicSettingsMock, + }), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPostMock(...args), + }, +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args), + completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), + login2FA: (...args: any[]) => login2FAMock(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), + getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), + prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args), + getAuthToken: (...args: any[]) => getAuthTokenMock(...args), + } +}) + +describe('WechatCallbackView', () => { + beforeEach(() => { + exchangePendingOAuthCompletionMock.mockReset() + completeWeChatOAuthRegistrationMock.mockReset() + login2FAMock.mockReset() + apiClientPostMock.mockReset() + sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() + getPublicSettingsMock.mockReset() + replaceMock.mockReset() + setTokenMock.mockReset() + setPendingAuthSessionMock.mockReset() + clearPendingAuthSessionMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + prepareOAuthBindAccessTokenCookieMock.mockReset() + getAuthTokenMock.mockReset() + fetchPublicSettingsMock.mockReset() + routeState.query = {} + appStoreState.cachedPublicSettings = null + appStoreState.publicSettingsLoaded = false + localStorage.clear() + locationState.current = { + href: 'http://localhost/auth/wechat/callback', + hash: '', + search: '', + pathname: '/auth/wechat/callback' + } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + }) + }) + + it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'mp', + redirect: '/profile', + } + appStoreState.cachedPublicSettings = { + wechat_oauth_enabled: true, + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + } + appStoreState.publicSettingsLoaded = true + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).not.toContain('mode=mp') + }) + + it('falls back to the query mode when capability settings cannot be confirmed', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'mp', + redirect: '/profile', + } + fetchPublicSettingsMock.mockResolvedValue(null) + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=mp') + }) + + it('ignores legacy aggregate wechat settings and reuses the query mode during bind recovery', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'open', + redirect: '/profile', + } + appStoreState.cachedPublicSettings = { + wechat_oauth_enabled: true, + } + appStoreState.publicSettingsLoaded = true + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=open') + }) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + locationState.current.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccessMock).toHaveBeenCalledWith('Login success') + expect(replaceMock).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + locationState.current.hash = + '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token') + expect(replaceMock).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith() + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1) + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') + expect(replaceMock).toHaveBeenCalledWith('/dashboard') + expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + redirect: '/profile/connections', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(clearPendingAuthSessionMock).toHaveBeenCalledTimes(1) + expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replaceMock).toHaveBeenCalledWith('/profile/connections') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ') + await wrapper.get('button').trigger('click') + await flushPromises() + + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', { + adoptDisplayName: false, + adoptAvatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') + expect(replaceMock).toHaveBeenCalledWith('/subscriptions') + }) + + it('offers existing-account email collection during invitation flow', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue(null) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const emailInput = wrapper.get('[data-testid="existing-account-email"]') + await emailInput.setValue('user@example.com') + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click') + + expect(replaceMock).toHaveBeenCalledTimes(1) + expect(replaceMock.mock.calls[0]?.[0]).toContain('/login?') + expect(replaceMock.mock.calls[0]?.[0]).toContain('wechat_bind_existing%3D1') + expect(replaceMock.mock.calls[0]?.[0]).toContain('email=user%40example.com') + expect(replaceMock.mock.calls[0]?.[0]).toContain('mode%3Dopen') + }) + + it('binds directly to the current signed-in account during invitation flow', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue('current-auth-token') + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.find('[data-testid="existing-account-email"]').exists()).toBe(false) + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click') + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fusage') + expect(locationState.current.href).toContain('mode=open') + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '', + }) + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('new-access-token') + expect(replaceMock).toHaveBeenCalledWith('/welcome') + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'wechat', + redirect: '/welcome', + }) + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + apiClientPostMock.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists', + }, + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + apiClientPostMock.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showErrorMock).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ + message: 'sent', + countdown: 60, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'new@example.com', + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="wechat-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="wechat-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="wechat-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('bind-access-token') + expect(replaceMock).toHaveBeenCalledWith('/profile/security') + }) + + it('allows switching from server-driven bind-login to create-account mode', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/welcome', + email: 'existing@example.com', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('button.btn-secondary').trigger('click') + await flushPromises() + + const createAccountEmail = wrapper.get('[data-testid="wechat-create-account-email"]') + expect((createAccountEmail.element as HTMLInputElement).value).toBe('existing@example.com') + }) + + it('reuses query email for bind-login when backend does not echo it back', async () => { + routeState.query = { + email: 'resume@example.com', + } + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/profile', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const bindEmail = wrapper.get('[data-testid="wechat-bind-login-email"]') + expect((bindEmail.element as HTMLInputElement).value).toBe('resume@example.com') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/profile', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccessMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com', + }, + }) + login2FAMock.mockResolvedValue({ + access_token: '2fa-access-token', + refresh_token: '2fa-refresh-token', + expires_in: 3600, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="wechat-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="wechat-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FAMock).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="wechat-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="wechat-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FAMock).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456', + }) + expect(setTokenMock).toHaveBeenCalledWith('2fa-access-token') + expect(replaceMock).toHaveBeenCalledWith('/profile') + expect(localStorage.getItem('refresh_token')).toBe('2fa-refresh-token') + }) + + it('restarts the current-user bind flow after returning from login', async () => { + routeState.query = { + wechat_bind_existing: '1', + redirect: '/profile', + mode: 'mp', + } + getAuthTokenMock.mockReturnValue('existing-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=mp') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') + }) + + it('redirects back to login instead of falling through when bind-existing resume has no auth token', async () => { + routeState.query = { + wechat_bind_existing: '1', + redirect: '/profile', + mode: 'mp', + email: 'resume@example.com', + } + getAuthTokenMock.mockReturnValue(null) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(replaceMock).toHaveBeenCalledTimes(1) + expect(replaceMock.mock.calls[0]?.[0]).toContain('/login?') + expect(replaceMock.mock.calls[0]?.[0]).toContain('wechat_bind_existing%3D1') + expect(replaceMock.mock.calls[0]?.[0]).toContain('mode%3Dmp') + expect(replaceMock.mock.calls[0]?.[0]).toContain('email=resume%40example.com') + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..93cd0e9478616934965cd52766eeaa23c23f5f09 --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts @@ -0,0 +1,93 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatPaymentCallbackView from '@/views/auth/WechatPaymentCallbackView.vue' + +const { replaceMock, routeState, locationState, showErrorMock } = vi.hoisted(() => ({ + replaceMock: vi.fn(), + routeState: { + query: {} as Record, + }, + locationState: { + current: { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string }, + }, + showErrorMock: vi.fn(), +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => { + if (key === 'auth.wechatPayment.callbackTitle') return '正在恢复微信支付' + if (key === 'auth.wechatPayment.callbackProcessing') return '正在恢复微信支付...' + if (key === 'auth.wechatPayment.backToPayment') return '返回支付页' + if (key === 'auth.wechatPayment.callbackMissingResumeToken') return '微信支付回调缺少恢复令牌。' + return key + }, + locale: { value: 'zh-CN' }, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +describe('WechatPaymentCallbackView', () => { + beforeEach(() => { + replaceMock.mockReset() + showErrorMock.mockReset() + routeState.query = {} + locationState.current = { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + }) + + it('redirects back to purchase with an opaque resume token from hash fragment', async () => { + locationState.current.hash = '#wechat_resume_token=resume-token-123&redirect=%2Fpurchase%3Ffrom%3Dwechat' + + mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).toHaveBeenCalledWith({ + path: '/purchase', + query: { + from: 'wechat', + wechat_resume: '1', + wechat_resume_token: 'resume-token-123', + }, + }) + }) + + it('shows an error when the callback payload is missing the resume token', async () => { + locationState.current.hash = '#payment_type=wxpay' + + const wrapper = mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).not.toHaveBeenCalled() + expect(showErrorMock).toHaveBeenCalledWith('微信支付回调缺少恢复令牌。') + expect(wrapper.text()).toContain('微信支付回调缺少恢复令牌。') + expect(wrapper.find('.bg-red-50').exists()).toBe(false) + }) +}) diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 6431ddf69f27e813981d39e019daa60e68f592da..57e81f4045d6630436cc90d4f137ec0d3d364954 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -15,6 +15,10 @@
+
+
+
@@ -22,8 +26,11 @@

- {{ isSuccess ? t('payment.result.success') : t('payment.result.failed') }} + {{ statusTitle }}

+

+ {{ t('payment.result.processingHint') }} +

@@ -54,7 +61,7 @@
{{ t('payment.orders.paymentMethod') }} - {{ t('payment.methods.' + order.payment_type, order.payment_type) }} + {{ t(paymentMethodI18nKey(order.payment_type), normalizedOrderPaymentType(order.payment_type)) }}
{{ t('payment.orders.status') }} @@ -75,7 +82,7 @@
{{ t('payment.orders.paymentMethod') }} - {{ t('payment.methods.' + returnInfo.type, returnInfo.type) }} + {{ t(paymentMethodI18nKey(returnInfo.type), normalizedOrderPaymentType(returnInfo.type)) }}
@@ -90,13 +97,15 @@ diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e2885c80fe8749b9c858e5ea444be8280cafbaf8..7d037917643b3b16db92ed5fc48837326942fbc0 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -23,20 +23,7 @@ :order-type="paymentState.orderType" @done="onPaymentDone" @success="onPaymentSuccess" - /> - - @@ -99,9 +86,6 @@ {{ t('payment.createOrder') }} ¥{{ totalAmount.toFixed(2) }} -
-

{{ errorMessage }}

-
@@ -185,9 +169,6 @@ {{ t('payment.createOrder') }} ¥{{ (feeRate > 0 ? subTotalAmount : selectedPlan.price).toFixed(2) }} -
-

{{ errorMessage }}

-