Commit e9de839d authored by IanShaw027's avatar IanShaw027
Browse files

feat: rebuild auth identity foundation flow

parent fbd0a2e3
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
var pendingAuthIntents = map[string]struct{}{
"login": {},
"bind_current_user": {},
"adopt_existing_user_by_email": {},
}
func validatePendingAuthIntent(value string) error {
if _, ok := pendingAuthIntents[value]; ok {
return nil
}
return fmt.Errorf("invalid pending auth intent %q", value)
}
// PendingAuthSession stores a short-lived post-auth decision session.
type PendingAuthSession struct {
ent.Schema
}
func (PendingAuthSession) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "pending_auth_sessions"},
}
}
func (PendingAuthSession) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (PendingAuthSession) Fields() []ent.Field {
return []ent.Field{
field.String("session_token").
MaxLen(255).
NotEmpty(),
field.String("intent").
MaxLen(40).
NotEmpty().
Validate(validatePendingAuthIntent),
field.String("provider_type").
MaxLen(20).
NotEmpty().
Validate(validateAuthProviderType),
field.String("provider_key").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("provider_subject").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Int64("target_user_id").
Optional().
Nillable(),
field.String("redirect_to").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("resolved_email").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("registration_password_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.JSON("upstream_identity_claims", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.JSON("local_flow_state", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.String("browser_session_key").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("completion_code_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Time("completion_code_expires_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("email_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("password_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("totp_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("expires_at").
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("consumed_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (PendingAuthSession) Edges() []ent.Edge {
return []ent.Edge{
edge.From("target_user", User.Type).
Ref("pending_auth_sessions").
Field("target_user_id").
Unique(),
edge.To("adoption_decision", IdentityAdoptionDecision.Type).
Unique(),
}
}
func (PendingAuthSession) Indexes() []ent.Index {
return []ent.Index{
index.Fields("session_token").Unique(),
index.Fields("target_user_id"),
index.Fields("expires_at"),
index.Fields("provider_type", "provider_key", "provider_subject"),
index.Fields("completion_code_hash"),
}
}
...@@ -72,6 +72,17 @@ func (User) Fields() []ent.Field { ...@@ -72,6 +72,17 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at"). field.Time("totp_enabled_at").
Optional(). Optional().
Nillable(), Nillable(),
field.String("signup_source").
MaxLen(20).
Default("email"),
field.Time("last_login_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("last_active_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知 // 余额不足通知
field.Bool("balance_notify_enabled"). field.Bool("balance_notify_enabled").
...@@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge { ...@@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type), edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type), edge.To("payment_orders", PaymentOrder.Type),
edge.To("auth_identities", AuthIdentity.Type),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
} }
} }
......
...@@ -24,18 +24,26 @@ type Tx struct { ...@@ -24,18 +24,26 @@ type Tx struct {
Announcement *AnnouncementClient Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders. // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient AnnouncementRead *AnnouncementReadClient
// AuthIdentity is the client for interacting with the AuthIdentity builders.
AuthIdentity *AuthIdentityClient
// AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders. // Group is the client for interacting with the Group builders.
Group *GroupClient Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient IdempotencyRecord *IdempotencyRecordClient
// IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders. // PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient PaymentProviderInstance *PaymentProviderInstanceClient
// PendingAuthSession is the client for interacting with the PendingAuthSession builders.
PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders. // PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
...@@ -202,12 +210,16 @@ func (tx *Tx) init() { ...@@ -202,12 +210,16 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config) tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.AuthIdentity = NewAuthIdentityClient(tx.config)
tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config) tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config) tx.Proxy = NewProxyClient(tx.config)
......
...@@ -45,6 +45,12 @@ type User struct { ...@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"` TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field. // TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// SignupSource holds the value of the "signup_source" field.
SignupSource string `json:"signup_source,omitempty"`
// LastLoginAt holds the value of the "last_login_at" field.
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
// LastActiveAt holds the value of the "last_active_at" field.
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
...@@ -83,11 +89,15 @@ type UserEdges struct { ...@@ -83,11 +89,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge. // PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
// AuthIdentities holds the value of the auth_identities edge.
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge. // UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a // loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not. // type was loaded (or requested) in eager-loading or not.
loadedTypes [11]bool loadedTypes [13]bool
} }
// APIKeysOrErr returns the APIKeys value or an error if the edge // APIKeysOrErr returns the APIKeys value or an error if the edge
...@@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { ...@@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"} return nil, &NotLoadedError{edge: "payment_orders"}
} }
// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
if e.loadedTypes[10] {
return e.AuthIdentities, nil
}
return nil, &NotLoadedError{edge: "auth_identities"}
}
// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
if e.loadedTypes[11] {
return e.PendingAuthSessions, nil
}
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading. // was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[10] { if e.loadedTypes[12] {
return e.UserAllowedGroups, nil return e.UserAllowedGroups, nil
} }
return nil, &NotLoadedError{edge: "user_allowed_groups"} return nil, &NotLoadedError{edge: "user_allowed_groups"}
...@@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) { ...@@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency: case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
default: default:
values[i] = new(sql.UnknownType) values[i] = new(sql.UnknownType)
...@@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error { ...@@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time) _m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time *_m.TotpEnabledAt = value.Time
} }
case user.FieldSignupSource:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field signup_source", values[i])
} else if value.Valid {
_m.SignupSource = value.String
}
case user.FieldLastLoginAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
} else if value.Valid {
_m.LastLoginAt = new(time.Time)
*_m.LastLoginAt = value.Time
}
case user.FieldLastActiveAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
} else if value.Valid {
_m.LastActiveAt = new(time.Time)
*_m.LastActiveAt = value.Time
}
case user.FieldBalanceNotifyEnabled: case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok { if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
...@@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { ...@@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m) return NewUserClient(_m.config).QueryPaymentOrders(_m)
} }
// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
return NewUserClient(_m.config).QueryAuthIdentities(_m)
}
// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m) return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
...@@ -482,6 +540,19 @@ func (_m *User) String() string { ...@@ -482,6 +540,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC)) builder.WriteString(v.Format(time.ANSIC))
} }
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("signup_source=")
builder.WriteString(_m.SignupSource)
builder.WriteString(", ")
if v := _m.LastLoginAt; v != nil {
builder.WriteString("last_login_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.LastActiveAt; v != nil {
builder.WriteString("last_active_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=") builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -43,6 +43,12 @@ const ( ...@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled" FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at" FieldTotpEnabledAt = "totp_enabled_at"
// FieldSignupSource holds the string denoting the signup_source field in the database.
FieldSignupSource = "signup_source"
// FieldLastLoginAt holds the string denoting the last_login_at field in the database.
FieldLastLoginAt = "last_login_at"
// FieldLastActiveAt holds the string denoting the last_active_at field in the database.
FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled" FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
...@@ -73,6 +79,10 @@ const ( ...@@ -73,6 +79,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages" EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders" EdgePaymentOrders = "payment_orders"
// EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
EdgeAuthIdentities = "auth_identities"
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups" EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database. // Table holds the table name of the user in the database.
...@@ -145,6 +155,20 @@ const ( ...@@ -145,6 +155,20 @@ const (
PaymentOrdersInverseTable = "payment_orders" PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id" PaymentOrdersColumn = "user_id"
// AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
AuthIdentitiesTable = "auth_identities"
// AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
// It exists in this package in order to avoid circular dependency with the "authidentity" package.
AuthIdentitiesInverseTable = "auth_identities"
// AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
AuthIdentitiesColumn = "user_id"
// PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
PendingAuthSessionsTable = "pending_auth_sessions"
// PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
// It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
PendingAuthSessionsInverseTable = "pending_auth_sessions"
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups" UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
...@@ -171,6 +195,9 @@ var Columns = []string{ ...@@ -171,6 +195,9 @@ var Columns = []string{
FieldTotpSecretEncrypted, FieldTotpSecretEncrypted,
FieldTotpEnabled, FieldTotpEnabled,
FieldTotpEnabledAt, FieldTotpEnabledAt,
FieldSignupSource,
FieldLastLoginAt,
FieldLastActiveAt,
FieldBalanceNotifyEnabled, FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType, FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold, FieldBalanceNotifyThreshold,
...@@ -232,6 +259,10 @@ var ( ...@@ -232,6 +259,10 @@ var (
DefaultNotes string DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool DefaultTotpEnabled bool
// DefaultSignupSource holds the default value on creation for the "signup_source" field.
DefaultSignupSource string
// SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
...@@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { ...@@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
} }
// BySignupSource orders the results by the signup_source field.
func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
}
// ByLastLoginAt orders the results by the last_login_at field.
func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
}
// ByLastActiveAt orders the results by the last_active_at field.
func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
}
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
...@@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { ...@@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
} }
} }
// ByAuthIdentitiesCount orders the results by auth_identities count.
func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
}
}
// ByAuthIdentities orders the results by auth_identities terms.
func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
}
}
// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count. // ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
...@@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { ...@@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
) )
} }
func newAuthIdentitiesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
}
func newPendingAuthSessionsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step { func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep( return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),
......
...@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { ...@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
} }
// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
func SignupSource(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
func LastLoginAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
func LastActiveAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User { func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
...@@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User { ...@@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
} }
// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
func SignupSourceEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
func SignupSourceNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
}
// SignupSourceIn applies the In predicate on the "signup_source" field.
func SignupSourceIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
}
// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
func SignupSourceNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
}
// SignupSourceGT applies the GT predicate on the "signup_source" field.
func SignupSourceGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldSignupSource, v))
}
// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
func SignupSourceGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldSignupSource, v))
}
// SignupSourceLT applies the LT predicate on the "signup_source" field.
func SignupSourceLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldSignupSource, v))
}
// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
func SignupSourceLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldSignupSource, v))
}
// SignupSourceContains applies the Contains predicate on the "signup_source" field.
func SignupSourceContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldSignupSource, v))
}
// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
func SignupSourceHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
}
// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
func SignupSourceHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
}
// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
func SignupSourceEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
}
// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
func SignupSourceContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
}
// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
func LastLoginAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
func LastLoginAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
}
// LastLoginAtIn applies the In predicate on the "last_login_at" field.
func LastLoginAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
}
// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
func LastLoginAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
}
// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
func LastLoginAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
}
// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
func LastLoginAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
}
// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
func LastLoginAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
}
// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
func LastLoginAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
}
// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
func LastLoginAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
}
// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
func LastLoginAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
}
// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
func LastActiveAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
func LastActiveAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
}
// LastActiveAtIn applies the In predicate on the "last_active_at" field.
func LastActiveAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
}
// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
func LastActiveAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
}
// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
func LastActiveAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
}
// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
func LastActiveAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
}
// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
func LastActiveAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
}
// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
func LastActiveAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
}
// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
func LastActiveAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
}
// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
func LastActiveAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
}
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User { func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
...@@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { ...@@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
}) })
} }
// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
func HasAuthIdentities() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAuthIdentitiesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
func HasPendingAuthSessions() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPendingAuthSessionsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User { func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {
......
...@@ -13,8 +13,10 @@ import ( ...@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/usagelog"
...@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { ...@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c return _c
} }
// SetSignupSource sets the "signup_source" field.
func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
_c.mutation.SetSignupSource(v)
return _c
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
if v != nil {
_c.SetSignupSource(*v)
}
return _c
}
// SetLastLoginAt sets the "last_login_at" field.
func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
_c.mutation.SetLastLoginAt(v)
return _c
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastLoginAt(*v)
}
return _c
}
// SetLastActiveAt sets the "last_active_at" field.
func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
_c.mutation.SetLastActiveAt(v)
return _c
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastActiveAt(*v)
}
return _c
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v) _c.mutation.SetBalanceNotifyEnabled(v)
...@@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { ...@@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...) return _c.AddPaymentOrderIDs(ids...)
} }
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
_c.mutation.AddAuthIdentityIDs(ids...)
return _c
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
_c.mutation.AddPendingAuthSessionIDs(ids...)
return _c
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation { func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation return _c.mutation
...@@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error { ...@@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v) _c.mutation.SetTotpEnabled(v)
} }
if _, ok := _c.mutation.SignupSource(); !ok {
v := user.DefaultSignupSource
_c.mutation.SetSignupSource(v)
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v) _c.mutation.SetBalanceNotifyEnabled(v)
...@@ -589,6 +667,14 @@ func (_c *UserCreate) check() error { ...@@ -589,6 +667,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok { if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
} }
if _, ok := _c.mutation.SignupSource(); !ok {
return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
}
if v, ok := _c.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
} }
...@@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { ...@@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value _node.TotpEnabledAt = &value
} }
if value, ok := _c.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
_node.SignupSource = value
}
if value, ok := _c.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
_node.LastLoginAt = &value
}
if value, ok := _c.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
_node.LastActiveAt = &value
}
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value _node.BalanceNotifyEnabled = value
...@@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { ...@@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
} }
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec return _node, _spec
} }
...@@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { ...@@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u return u
} }
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
u.Set(user.FieldSignupSource, v)
return u
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
u.SetExcluded(user.FieldSignupSource)
return u
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastLoginAt, v)
return u
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
u.SetExcluded(user.FieldLastLoginAt)
return u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
u.SetNull(user.FieldLastLoginAt)
return u
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastActiveAt, v)
return u
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
u.SetExcluded(user.FieldLastActiveAt)
return u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
u.SetNull(user.FieldLastActiveAt)
return u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v) u.Set(user.FieldBalanceNotifyEnabled, v)
...@@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { ...@@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
}) })
} }
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) { return u.Update(func(s *UserUpsert) {
...@@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { ...@@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
}) })
} }
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) { return u.Update(func(s *UserUpsert) {
......
...@@ -15,8 +15,10 @@ import ( ...@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
...@@ -44,6 +46,8 @@ type UserQuery struct { ...@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery withPaymentOrders *PaymentOrderQuery
withAuthIdentities *AuthIdentityQuery
withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector) modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
...@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { ...@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query return query
} }
// QueryAuthIdentities chains the current query on the "auth_identities" edge.
func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query() query := (&UserAllowedGroupClient{config: _q.config}).Query()
...@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { ...@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(), withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(),
withAuthIdentities: _q.withAuthIdentities.Clone(),
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query. // clone intermediate query.
sql: _q.sql.Clone(), sql: _q.sql.Clone(),
...@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu ...@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q return _q
} }
// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAuthIdentities = query
return _q
}
// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPendingAuthSessions = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
...@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var ( var (
nodes = []*User{} nodes = []*User{}
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [11]bool{ loadedTypes = [13]bool{
_q.withAPIKeys != nil, _q.withAPIKeys != nil,
_q.withRedeemCodes != nil, _q.withRedeemCodes != nil,
_q.withSubscriptions != nil, _q.withSubscriptions != nil,
...@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil, _q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil, _q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil, _q.withPaymentOrders != nil,
_q.withAuthIdentities != nil,
_q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil, _q.withUserAllowedGroups != nil,
} }
) )
...@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e ...@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err return nil, err
} }
} }
if query := _q.withAuthIdentities; query != nil {
if err := _q.loadAuthIdentities(ctx, query, nodes,
func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
return nil, err
}
}
if query := _q.withPendingAuthSessions; query != nil {
if err := _q.loadPendingAuthSessions(ctx, query, nodes,
func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
func(n *User, e *PendingAuthSession) {
n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
}); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil { if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes, if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
...@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ ...@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
} }
return nil return nil
} }
func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(authidentity.FieldUserID)
}
query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
}
query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.TargetUserID
if fk == nil {
return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes)) fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User) nodeids := make(map[int64]*User)
......
...@@ -13,8 +13,10 @@ import ( ...@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/redeemcode"
...@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { ...@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u return _u
} }
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v) _u.mutation.SetBalanceNotifyEnabled(v)
...@@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { ...@@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...) return _u.AddPaymentOrderIDs(ids...)
} }
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation { func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { ...@@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...) return _u.RemovePaymentOrderIDs(ids...)
} }
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation. // Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) { func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil { if err := _u.defaults(); err != nil {
...@@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error { ...@@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
} }
} }
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil return nil
} }
...@@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() { if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
} }
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
} }
...@@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label} err = &NotFoundError{user.Label}
...@@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { ...@@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u return _u
} }
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v) _u.mutation.SetBalanceNotifyEnabled(v)
...@@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { ...@@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...) return _u.AddPaymentOrderIDs(ids...)
} }
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder. // Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation { func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation return _u.mutation
...@@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne ...@@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...) return _u.RemovePaymentOrderIDs(ids...)
} }
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder. // Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...) _u.mutation.Where(ps...)
...@@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error { ...@@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
} }
} }
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil return nil
} }
...@@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { ...@@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() { if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
} }
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
} }
...@@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { ...@@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
} }
_spec.Edges.Add = append(_spec.Edges.Add, edge) _spec.Edges.Add = append(_spec.Edges.Add, edge)
} }
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config} _node = &User{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues
......
...@@ -1608,6 +1608,9 @@ func (c *Config) Validate() error { ...@@ -1608,6 +1608,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled") return fmt.Errorf("security.csp.policy is required when CSP is enabled")
} }
if c.LinuxDo.Enabled { if c.LinuxDo.Enabled {
if !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.ClientID) == "" { if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
} }
...@@ -1629,9 +1632,6 @@ func (c *Config) Validate() error { ...@@ -1629,9 +1632,6 @@ func (c *Config) Validate() error {
default: default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
} }
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
...@@ -1663,6 +1663,12 @@ func (c *Config) Validate() error { ...@@ -1663,6 +1663,12 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
} }
if c.OIDC.Enabled { if c.OIDC.Enabled {
if !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
}
if !c.OIDC.ValidateIDToken {
return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.ClientID) == "" { if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
} }
...@@ -1685,9 +1691,6 @@ func (c *Config) Validate() error { ...@@ -1685,9 +1691,6 @@ func (c *Config) Validate() error {
default: default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
} }
if method == "none" && !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" { strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
......
...@@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// Check if ops monitoring is enabled (respects config.ops.enabled) // Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
...@@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{} paymentCfg = &service.PaymentConfig{}
} }
response.Success(c, dto.SystemSettings{ payload := dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
...@@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
}) }
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
} }
// UpdateSettingsRequest 更新设置请求 // UpdateSettingsRequest 更新设置请求
...@@ -276,9 +282,30 @@ type UpdateSettingsRequest struct { ...@@ -276,9 +282,30 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration // Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"` EnableModelFallback bool `json:"enable_model_fallback"`
...@@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数 // 验证参数
if req.DefaultConcurrency < 1 { if req.DefaultConcurrency < 1 {
...@@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
...@@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid") response.BadRequest(c, "OIDC scopes must contain openid")
return return
} }
if !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled")
return
}
if !req.OIDCConnectValidateIDToken {
response.BadRequest(c, "OIDC ID Token validation must be enabled")
return
}
switch req.OIDCConnectTokenAuthMethod { switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none": case "", "client_secret_post", "client_secret_basic", "none":
default: default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return return
} }
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
return
}
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return return
} }
if req.OIDCConnectValidateIDToken { if req.OIDCConnectAllowedSigningAlgs == "" {
if req.OIDCConnectAllowedSigningAlgs == "" { response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") return
return
}
} }
if req.OIDCConnectJWKSURL != "" { if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
...@@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
},
LinuxDo: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
},
OIDC: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
},
WeChat: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
// Update payment configuration (integrated into system settings). // Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe). // Skip if no payment fields were provided (prevents accidental wipe).
...@@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions { for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
...@@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{} updatedPaymentCfg = &service.PaymentConfig{}
} }
response.Success(c, dto.SystemSettings{ payload := dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
...@@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
}) }
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
} }
// hasPaymentFields returns true if any payment-related field was explicitly provided. // hasPaymentFields returns true if any payment-related field was explicitly provided.
...@@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto ...@@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized return normalized
} }
func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
if input == nil {
return nil
}
normalized := normalizeDefaultSubscriptions(*input)
return &normalized
}
func float64ValueOrDefault(value *float64, fallback float64) float64 {
if value == nil {
return fallback
}
return *value
}
func intValueOrDefault(value *int, fallback int) int {
if value == nil {
return fallback
}
return *value
}
func boolValueOrDefault(value *bool, fallback bool) bool {
if value == nil {
return fallback
}
return *value
}
func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
if input == nil {
return fallback
}
result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
for _, item := range *input {
result = append(result, service.DefaultSubscriptionSetting{
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
})
}
return result
}
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
data := make(map[string]any)
raw, err := json.Marshal(settings)
if err == nil {
_ = json.Unmarshal(raw, &data)
}
if authSourceDefaults == nil {
authSourceDefaults = &service.AuthSourceDefaultSettings{}
}
data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
}
func equalStringSlice(a, b []string) bool { func equalStringSlice(a, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false
......
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type settingHandlerRepoStub struct {
values map[string]string
lastUpdates map[string]string
}
func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.lastUpdates = make(map[string]string, len(settings))
for key, value := range settings {
s.lastUpdates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for key, value := range s.values {
out[key] = value
}
return out, nil
}
func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
handler.GetSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 9.5, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
require.True(t, ok)
require.Len(t, subscriptions, 1)
}
func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
"promo_code_enabled": true,
"auth_source_default_email_balance": 12.75,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 12.75, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
}
...@@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ...@@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
} }
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil { if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) { if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
...@@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ...@@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
ProviderKey: "linuxdo", ProviderKey: "linuxdo",
ProviderSubject: subject, ProviderSubject: subject,
}, },
TargetUserID: &user.ID,
ResolvedEmail: email, ResolvedEmail: email,
RedirectTo: redirectTo, RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey, BrowserSessionKey: browserSessionKey,
...@@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ...@@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
} }
type completeLinuxDoOAuthRequest struct { type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
...@@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return return
} }
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
package handler package handler
import ( import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) { ...@@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r")) require.Equal(t, "", singleLine("\n\t\r"))
} }
func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-1").
SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Display",
"suggested_avatar_url": "https://cdn.example/linuxdo.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptAvatar: true,
})
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "LinuxDo Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
package handler package handler
import ( import (
"context"
"errors"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -26,6 +33,7 @@ const ( ...@@ -26,6 +33,7 @@ const (
type oauthPendingSessionPayload struct { type oauthPendingSessionPayload struct {
Intent string Intent string
Identity service.PendingAuthIdentityKey Identity service.PendingAuthIdentityKey
TargetUserID *int64
ResolvedEmail string ResolvedEmail string
RedirectTo string RedirectTo string
BrowserSessionKey string BrowserSessionKey string
...@@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct { ...@@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct {
CompletionResponse map[string]any CompletionResponse map[string]any
} }
type oauthAdoptionDecisionRequest struct {
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil { if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
...@@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen ...@@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
Intent: strings.TrimSpace(payload.Intent), Intent: strings.TrimSpace(payload.Intent),
Identity: payload.Identity, Identity: payload.Identity,
TargetUserID: payload.TargetUserID,
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
RedirectTo: strings.TrimSpace(payload.RedirectTo), RedirectTo: strings.TrimSpace(payload.RedirectTo),
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
...@@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { ...@@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
} }
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput {
input := service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
}
if r.AdoptDisplayName != nil {
input.AdoptDisplayName = *r.AdoptDisplayName
}
if r.AdoptAvatar != nil {
input.AdoptAvatar = *r.AdoptAvatar
}
return input
}
func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
var req oauthAdoptionDecisionRequest
if c == nil || c.Request == nil || c.Request.Body == nil {
return req, nil
}
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, io.EOF) {
return req, nil
}
return req, err
}
return req, nil
}
func persistPendingOAuthAdoptionDecision(
c *gin.Context,
svc *service.AuthPendingIdentityService,
sessionID int64,
req oauthAdoptionDecisionRequest,
) error {
if !req.hasDecision() {
return nil
}
if svc == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil {
return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return nil
}
func cloneOAuthMetadata(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func normalizeAdoptedOAuthDisplayName(value string) string {
value = strings.TrimSpace(value)
if len([]rune(value)) > 100 {
value = string([]rune(value)[:100])
}
return value
}
func (h *AuthHandler) entClient() *dbent.Client {
if h == nil || h.authService == nil {
return nil
}
return h.authService.EntClient()
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
existing, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
Only(c.Request.Context())
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
}
if existing != nil && !req.hasDecision() {
return existing, nil
}
if existing == nil && !req.hasDecision() {
return nil, nil
}
input := service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
}
if existing != nil {
input.AdoptDisplayName = existing.AdoptDisplayName
input.AdoptAvatar = existing.AdoptAvatar
input.IdentityID = existing.IdentityID
}
if req.AdoptDisplayName != nil {
input.AdoptDisplayName = *req.AdoptDisplayName
}
if req.AdoptAvatar != nil {
input.AdoptAvatar = *req.AdoptAvatar
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return *session.TargetUserID, nil
}
email := strings.TrimSpace(session.ResolvedEmail)
if email == "" {
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
}
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(email)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
}
return 0, err
}
return userEntity.ID, nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
}
switch strings.TrimSpace(session.ProviderType) {
case "oidc":
issuer := strings.TrimSpace(session.ProviderKey)
if issuer == "" {
issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
}
if issuer == "" {
return nil
}
return &issuer
default:
issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
if issuer == "" {
return nil
}
return &issuer
}
}
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
client := tx.Client()
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if identity != nil {
if identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return identity, nil
}
create := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(strings.TrimSpace(session.ProviderType)).
SetProviderKey(strings.TrimSpace(session.ProviderKey)).
SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
if issuer := oauthIdentityIssuer(session); issuer != nil {
create = create.SetIssuer(strings.TrimSpace(*issuer))
}
return create.Save(ctx)
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
) error {
if client == nil || session == nil || decision == nil {
return nil
}
if !decision.AdoptDisplayName && !decision.AdoptAvatar {
return nil
}
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
if err != nil {
return err
}
targetUserID = resolvedUserID
}
adoptedDisplayName := ""
if decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
if decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
return err
}
}
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
metadata := cloneOAuthMetadata(identity.Metadata)
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
if decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
return err
}
}
return tx.Commit()
}
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 { if len(payload) == 0 || len(upstream) == 0 {
return return
...@@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { ...@@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
} }
adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
if err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken, err := readOAuthPendingSessionCookie(c) sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" { if err != nil || strings.TrimSpace(sessionToken) == "" {
...@@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { ...@@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingSessionWantsInvitation(payload) { if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
if err != nil {
response.ErrorFrom(c, err)
return
}
_ = decision
}
response.Success(c, payload)
return
}
if !adoptionDecision.hasDecision() {
response.Success(c, payload) response.Success(c, payload)
return return
} }
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearCookies() clearCookies()
......
package handler package handler
import ( import (
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing" "testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
) )
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
...@@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t * ...@@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
require.Equal(t, true, payload["adoption_required"]) require.Equal(t, true, payload["adoption_required"])
} }
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("linuxdo-123@linuxdo-connect.invalid").
SetUsername("legacy-name").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("123").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "Alice Example",
"suggested_avatar_url": "https://cdn.example/alice.png",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
previewRecorder := httptest.NewRecorder()
previewCtx, _ := gin.CreateTestContext(previewRecorder)
previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
previewCtx.Request = previewReq
handler.ExchangePendingOAuthCompletion(previewCtx)
require.Equal(t, http.StatusOK, previewRecorder.Code)
previewData := decodeJSONResponseData(t, previewRecorder)
require.Equal(t, "Alice Example", previewData["suggested_display_name"])
require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
require.Equal(t, true, previewData["adoption_required"])
storedUser, err := client.User.Get(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, "legacy-name", storedUser.Username)
previewSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.Nil(t, previewSession.ConsumedAt)
body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
finalizeRecorder := httptest.NewRecorder()
finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
finalizeReq.Header.Set("Content-Type", "application/json")
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
finalizeCtx.Request = finalizeReq
handler.ExchangePendingOAuthCompletion(finalizeCtx)
require.Equal(t, http.StatusOK, finalizeRecorder.Code)
storedUser, err = client.User.Get(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, "Alice Example", storedUser.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
}
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
},
}, cfg)
authSvc := service.NewAuthService(
client,
&oauthPendingFlowUserRepo{client: client},
nil,
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
nil,
nil,
nil,
nil,
nil,
)
return &AuthHandler{
authService: authSvc,
settingSvc: settingSvc,
}, client
}
func boolSettingValue(v bool) string {
if v {
return "true"
}
return "false"
}
func boolPtr(v bool) *bool {
return &v
}
type oauthPendingFlowSettingRepoStub struct {
values map[string]string
}
func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
return nil, service.ErrSettingNotFound
}
func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
value, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
return nil
}
func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
result[key] = value
}
}
return result, nil
}
func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
result := make(map[string]string, len(s.values))
for key, value := range s.values {
result[key] = value
}
return result, nil
}
func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
return nil
}
type oauthPendingFlowRefreshTokenCacheStub struct{}
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
var envelope struct {
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
return envelope.Data
}
func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
var payload map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
return payload
}
type oauthPendingFlowUserRepo struct {
client *dbent.Client
}
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
entity, err := r.client.User.Create().
SetEmail(user.Email).
SetUsername(user.Username).
SetNotes(user.Notes).
SetPasswordHash(user.PasswordHash).
SetRole(user.Role).
SetBalance(user.Balance).
SetConcurrency(user.Concurrency).
SetStatus(user.Status).
SetSignupSource(user.SignupSource).
SetNillableLastLoginAt(user.LastLoginAt).
SetNillableLastActiveAt(user.LastActiveAt).
Save(ctx)
if err != nil {
return err
}
user.ID = entity.ID
user.CreatedAt = entity.CreatedAt
user.UpdatedAt = entity.UpdatedAt
return nil
}
func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
entity, err := r.client.User.Get(ctx, id)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrUserNotFound
}
return nil, err
}
return oauthPendingFlowServiceUser(entity), nil
}
func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrUserNotFound
}
return nil, err
}
return oauthPendingFlowServiceUser(entity), nil
}
func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
entity, err := r.client.User.UpdateOneID(user.ID).
SetEmail(user.Email).
SetUsername(user.Username).
SetNotes(user.Notes).
SetPasswordHash(user.PasswordHash).
SetRole(user.Role).
SetBalance(user.Balance).
SetConcurrency(user.Concurrency).
SetStatus(user.Status).
SetSignupSource(user.SignupSource).
SetNillableLastLoginAt(user.LastLoginAt).
SetNillableLastActiveAt(user.LastActiveAt).
Save(ctx)
if err != nil {
return err
}
user.UpdatedAt = entity.UpdatedAt
return nil
}
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
return r.client.User.DeleteOneID(id).Exec(ctx)
}
func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, service.ErrUserNotFound
}
func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
return nil
}
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
panic("unexpected UpdateBalance call")
}
func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
panic("unexpected DeductBalance call")
}
func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected UpdateConcurrency call")
}
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
return count > 0, err
}
func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected RemoveGroupFromUserAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
panic("unexpected EnableTotp call")
}
func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
panic("unexpected DisableTotp call")
}
func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
if entity == nil {
return nil
}
return &service.User{
ID: entity.ID,
Email: entity.Email,
Username: entity.Username,
Notes: entity.Notes,
PasswordHash: entity.PasswordHash,
Role: entity.Role,
Balance: entity.Balance,
Concurrency: entity.Concurrency,
Status: entity.Status,
SignupSource: entity.SignupSource,
LastLoginAt: entity.LastLoginAt,
LastActiveAt: entity.LastActiveAt,
CreatedAt: entity.CreatedAt,
UpdatedAt: entity.UpdatedAt,
}
}
...@@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ...@@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
) )
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil { if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) { if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
...@@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ...@@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderKey: issuer, ProviderKey: issuer,
ProviderSubject: subject, ProviderSubject: subject,
}, },
TargetUserID: &user.ID,
ResolvedEmail: email, ResolvedEmail: email,
RedirectTo: redirectTo, RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey, BrowserSessionKey: browserSessionKey,
...@@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ...@@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
} }
type completeOIDCOAuthRequest struct { type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
...@@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return return
} }
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
package handler package handler
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
...@@ -12,7 +13,13 @@ import ( ...@@ -12,7 +13,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { ...@@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e, E: e,
} }
} }
func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-subject-1").
SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
SetBrowserSessionKey("oidc-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
"suggested_display_name": "OIDC Display",
"suggested_avatar_url": "https://cdn.example/oidc.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptAvatar: true,
})
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "OIDC Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example.com"),
authidentity.ProviderSubjectEQ("oidc-subject-1"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
wechatOAuthCookieMaxAgeSec = 10 * 60
wechatOAuthStateCookieName = "wechat_oauth_state"
wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
wechatOAuthIntentCookieName = "wechat_oauth_intent"
wechatOAuthModeCookieName = "wechat_oauth_mode"
wechatOAuthDefaultRedirectTo = "/dashboard"
wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
wechatOAuthProviderKey = "wechat-main"
wechatOAuthIntentLogin = "login"
wechatOAuthIntentBind = "bind_current_user"
wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
)
var (
wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
)
type wechatOAuthConfig struct {
mode string
appID string
appSecret string
authorizeURL string
scope string
redirectURI string
frontendCallback string
}
type wechatOAuthTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
OpenID string `json:"openid"`
Scope string `json:"scope"`
UnionID string `json:"unionid"`
ErrCode int64 `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
type wechatOAuthUserInfoResponse struct {
OpenID string `json:"openid"`
Nickname string `json:"nickname"`
HeadImgURL string `json:"headimgurl"`
UnionID string `json:"unionid"`
ErrCode int64 `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
// browser cookies required by the rebuild pending-auth bridge.
func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = wechatOAuthDefaultRedirectTo
}
browserSessionKey, err := generateOAuthPendingBrowserSession()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
return
}
intent := normalizeWeChatOAuthIntent(c.Query("intent"))
secureCookie := isRequestHTTPS(c)
wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
clearOAuthPendingSessionCookie(c, secureCookie)
authURL, err := buildWeChatAuthorizeURL(cfg, state)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
// and stores the result in the unified pending-auth flow.
func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
frontendCallback := wechatOAuthFrontendCallback()
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = wechatOAuthDefaultRedirectTo
}
browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
if strings.TrimSpace(browserSessionKey) == "" {
redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
return
}
intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
if err != nil || strings.TrimSpace(mode) == "" {
redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
return
}
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
if err != nil {
redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
if err != nil {
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
return
}
unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
providerSubject := firstNonEmpty(unionid, openid)
if providerSubject == "" {
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
return
}
username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
email := wechatSyntheticEmail(providerSubject)
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": providerSubject,
"openid": openid,
"unionid": unionid,
"mode": cfg.mode,
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
}
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
// validating the invitation code and consuming the current pending browser session.
// POST /api/v1/auth/oauth/wechat/complete-registration
func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
var req completeWeChatOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
secureCookie := isRequestHTTPS(c)
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
return
}
pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
if email == "" || username == "" {
response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return
}
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) createWeChatPendingSession(
c *gin.Context,
intent string,
providerSubject string,
email string,
redirectTo string,
browserSessionKey string,
upstreamClaims map[string]any,
tokenPair *service.TokenPair,
authErr error,
) error {
completionResponse := map[string]any{
"redirect": redirectTo,
}
if authErr != nil {
if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
completionResponse["error"] = "invitation_required"
} else {
return authErr
}
} else if tokenPair != nil {
completionResponse["access_token"] = tokenPair.AccessToken
completionResponse["refresh_token"] = tokenPair.RefreshToken
completionResponse["expires_in"] = tokenPair.ExpiresIn
completionResponse["token_type"] = "Bearer"
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: intent,
Identity: service.PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: wechatOAuthProviderKey,
ProviderSubject: providerSubject,
},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: completionResponse,
})
}
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
mode, err := resolveWeChatOAuthMode(rawMode, c)
if err != nil {
return wechatOAuthConfig{}, err
}
apiBaseURL := ""
if h != nil && h.settingSvc != nil {
settings, err := h.settingSvc.GetAllSettings(ctx)
if err == nil && settings != nil {
apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
}
}
cfg := wechatOAuthConfig{
mode: mode,
redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
frontendCallback: wechatOAuthFrontendCallback(),
}
switch mode {
case "mp":
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
cfg.scope = "snsapi_userinfo"
default:
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
cfg.scope = "snsapi_login"
}
if cfg.appID == "" || cfg.appSecret == "" {
return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
}
if strings.TrimSpace(cfg.redirectURI) == "" {
return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
}
return cfg, nil
}
func wechatOAuthFrontendCallback() string {
return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB)
}
func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
mode := strings.ToLower(strings.TrimSpace(rawMode))
if mode == "" {
if isWeChatBrowserRequest(c) {
return "mp", nil
}
return "open", nil
}
if mode != "open" && mode != "mp" {
return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
}
return mode, nil
}
func isWeChatBrowserRequest(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
}
func normalizeWeChatOAuthIntent(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", "login":
return wechatOAuthIntentLogin
case "bind", "bind_current_user":
return wechatOAuthIntentBind
case "adopt", "adopt_existing_user_by_email":
return wechatOAuthIntentAdoptEmail
default:
return wechatOAuthIntentLogin
}
}
func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
u, err := url.Parse(cfg.authorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize url: %w", err)
}
query := u.Query()
query.Set("appid", cfg.appID)
query.Set("redirect_uri", cfg.redirectURI)
query.Set("response_type", "code")
query.Set("scope", cfg.scope)
query.Set("state", state)
u.RawQuery = query.Encode()
u.Fragment = "wechat_redirect"
return u.String(), nil
}
func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
callbackPath = strings.TrimSpace(callbackPath)
if callbackPath == "" {
return ""
}
if raw := strings.TrimSpace(apiBaseURL); raw != "" {
if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
basePath := strings.TrimRight(parsed.EscapedPath(), "/")
targetPath := callbackPath
if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
} else if basePath != "" {
targetPath = basePath + callbackPath
}
return parsed.Scheme + "://" + parsed.Host + targetPath
}
}
if c == nil || c.Request == nil {
return ""
}
scheme := "http"
if isRequestHTTPS(c) {
scheme = "https"
}
host := strings.TrimSpace(c.Request.Host)
if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
host = forwardedHost
}
if host == "" {
return ""
}
return scheme + "://" + host + callbackPath
}
func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
if err != nil {
return nil, nil, err
}
userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
if err != nil {
return nil, nil, err
}
return tokenResp, userInfo, nil
}
func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
if err != nil {
return nil, fmt.Errorf("parse wechat access token url: %w", err)
}
query := endpoint.Query()
query.Set("appid", cfg.appID)
query.Set("secret", cfg.appSecret)
query.Set("code", strings.TrimSpace(code))
query.Set("grant_type", "authorization_code")
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("build wechat access token request: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request wechat access token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read wechat access token response: %w", err)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
}
var tokenResp wechatOAuthTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("decode wechat access token response: %w", err)
}
if tokenResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, fmt.Errorf("wechat access token missing access_token")
}
return &tokenResp, nil
}
func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
if tokenResp == nil {
return nil, fmt.Errorf("wechat token response is nil")
}
endpoint, err := url.Parse(wechatOAuthUserInfoURL)
if err != nil {
return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
}
query := endpoint.Query()
query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
query.Set("lang", "zh_CN")
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("build wechat userinfo request: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request wechat userinfo: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read wechat userinfo response: %w", err)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
}
var userInfo wechatOAuthUserInfoResponse
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
}
if userInfo.ErrCode != 0 {
return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
}
return &userInfo, nil
}
func wechatSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
}
func wechatFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return "wechat_user"
}
return "wechat_" + truncateFragmentValue(subject)
}
func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: wechatOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func wechatClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: wechatOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
//go:build unit
package handler
import (
"bytes"
"context"
"database/sql"
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
c.Request.Host = "api.example.com"
handler := &AuthHandler{}
handler.WeChatOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.NotEmpty(t, location)
require.Contains(t, location, "open.weixin.qq.com")
require.Contains(t, location, "appid=wx-open-app")
require.Contains(t, location, "scope=snsapi_login")
cookies := recorder.Result().Cookies()
require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
}
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
ctx := context.Background()
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "wechat", session.ProviderType)
require.Equal(t, "wechat-main", session.ProviderKey)
require.Equal(t, "union-456", session.ProviderSubject)
require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, true)
defer client.Close()
ctx := context.Background()
redeemRepo := repository.NewRedeemCodeRepository(client)
require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
Code: "invite-1",
Type: service.RedeemTypeInvitation,
Status: service.StatusUnused,
}))
callbackRecorder := httptest.NewRecorder()
callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
callbackReq.Host = "api.example.com"
callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
callbackCtx.Request = callbackReq
handler.WeChatOAuthCallback(callbackCtx)
require.Equal(t, http.StatusFound, callbackRecorder.Code)
require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
pendingSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
completeRecorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(completeRecorder)
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
completeReq.Header.Set("Content-Type", "application/json")
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
completeCtx.Request = completeReq
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
userRepo := &oauthPendingFlowUserRepo{client: client}
redeemRepo := repository.NewRedeemCodeRepository(client)
settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
},
}, &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
})
authSvc := service.NewAuthService(
client,
userRepo,
redeemRepo,
&wechatOAuthRefreshTokenCacheStub{},
&config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
},
settingSvc,
nil,
nil,
nil,
nil,
nil,
)
return &AuthHandler{
authService: authSvc,
settingSvc: settingSvc,
}, client
}
func encodedCookie(name, value string) *http.Cookie {
return &http.Cookie{
Name: name,
Value: encodeCookieValue(value),
Path: "/",
}
}
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
for _, cookie := range cookies {
if cookie.Name == name {
return cookie
}
}
return nil
}
func decodeCookieValueForTest(t *testing.T, value string) string {
t.Helper()
raw, err := base64.RawURLEncoding.DecodeString(value)
require.NoError(t, err)
return string(raw)
}
type wechatOAuthSettingRepoStub struct {
values map[string]string
}
func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
return nil, service.ErrSettingNotFound
}
func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
value, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
return nil
}
func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
result[key] = value
}
}
return result, nil
}
func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
result := make(map[string]string, len(s.values))
for key, value := range s.values {
result[key] = value
}
return result, nil
}
func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
return nil
}
type wechatOAuthRefreshTokenCacheStub struct{}
func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment