Commit b017f461 authored by 陈曦's avatar 陈曦
Browse files

confict fixed by v117

parents 11b97147 d162604f
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
var pendingAuthIntents = map[string]struct{}{
"login": {},
"bind_current_user": {},
"adopt_existing_user_by_email": {},
}
func validatePendingAuthIntent(value string) error {
if _, ok := pendingAuthIntents[value]; ok {
return nil
}
return fmt.Errorf("invalid pending auth intent %q", value)
}
// PendingAuthSession stores a short-lived post-auth decision session.
type PendingAuthSession struct {
ent.Schema
}
func (PendingAuthSession) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "pending_auth_sessions"},
}
}
func (PendingAuthSession) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (PendingAuthSession) Fields() []ent.Field {
return []ent.Field{
field.String("session_token").
MaxLen(255).
NotEmpty(),
field.String("intent").
MaxLen(40).
NotEmpty().
Validate(validatePendingAuthIntent),
field.String("provider_type").
MaxLen(20).
NotEmpty().
Validate(validateAuthProviderType),
field.String("provider_key").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("provider_subject").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Int64("target_user_id").
Optional().
Nillable(),
field.String("redirect_to").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("resolved_email").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("registration_password_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.JSON("upstream_identity_claims", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.JSON("local_flow_state", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.String("browser_session_key").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("completion_code_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Time("completion_code_expires_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("email_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("password_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("totp_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("expires_at").
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("consumed_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (PendingAuthSession) Edges() []ent.Edge {
return []ent.Edge{
edge.From("target_user", User.Type).
Ref("pending_auth_sessions").
Field("target_user_id").
Unique(),
edge.To("adoption_decision", IdentityAdoptionDecision.Type).
Annotations(entsql.OnDelete(entsql.Cascade)).
Unique(),
}
}
func (PendingAuthSession) Indexes() []ent.Index {
return []ent.Index{
index.Fields("session_token").Unique(),
index.Fields("target_user_id"),
index.Fields("expires_at"),
index.Fields("provider_type", "provider_key", "provider_subject"),
index.Fields("completion_code_hash"),
}
}
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
......@@ -72,6 +74,24 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
field.String("signup_source").
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
}
}).
Default("email"),
field.Time("last_login_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("last_active_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
......@@ -88,6 +108,10 @@ func (User) Fields() []ent.Field {
field.Float("total_recharged").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
// 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。
field.Int("rpm_limit").
Default(0),
}
}
......@@ -104,6 +128,9 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
edge.To("auth_identities", AuthIdentity.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
......
......@@ -24,18 +24,34 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// AuthIdentity is the client for interacting with the AuthIdentity builders.
AuthIdentity *AuthIdentityClient
// AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
AuthIdentityChannel *AuthIdentityChannelClient
// ChannelMonitor is the client for interacting with the ChannelMonitor builders.
ChannelMonitor *ChannelMonitorClient
// ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
// ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
// IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
// PendingAuthSession is the client for interacting with the PendingAuthSession builders.
PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
......@@ -202,12 +218,20 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.AuthIdentity = NewAuthIdentityClient(tx.config)
tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
......
......@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// SignupSource holds the value of the "signup_source" field.
SignupSource string `json:"signup_source,omitempty"`
// LastLoginAt holds the value of the "last_login_at" field.
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
// LastActiveAt holds the value of the "last_active_at" field.
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
......@@ -55,6 +61,8 @@ type User struct {
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
// TotalRecharged holds the value of the "total_recharged" field.
TotalRecharged float64 `json:"total_recharged,omitempty"`
// RpmLimit holds the value of the "rpm_limit" field.
RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
......@@ -83,11 +91,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
// AuthIdentities holds the value of the auth_identities edge.
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [11]bool
loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
......@@ -180,10 +192,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
if e.loadedTypes[10] {
return e.AuthIdentities, nil
}
return nil, &NotLoadedError{edge: "auth_identities"}
}
// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
if e.loadedTypes[11] {
return e.PendingAuthSessions, nil
}
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[10] {
if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
......@@ -198,11 +228,11 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
......@@ -312,6 +342,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
case user.FieldSignupSource:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field signup_source", values[i])
} else if value.Valid {
_m.SignupSource = value.String
}
case user.FieldLastLoginAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
} else if value.Valid {
_m.LastLoginAt = new(time.Time)
*_m.LastLoginAt = value.Time
}
case user.FieldLastActiveAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
} else if value.Valid {
_m.LastActiveAt = new(time.Time)
*_m.LastActiveAt = value.Time
}
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
......@@ -343,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.TotalRecharged = value.Float64
}
case user.FieldRpmLimit:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
} else if value.Valid {
_m.RpmLimit = int(value.Int64)
}
default:
_m.selectValues.Set(columns[i], values[i])
}
......@@ -406,6 +462,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
return NewUserClient(_m.config).QueryAuthIdentities(_m)
}
// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
......@@ -482,6 +548,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("signup_source=")
builder.WriteString(_m.SignupSource)
builder.WriteString(", ")
if v := _m.LastLoginAt; v != nil {
builder.WriteString("last_login_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.LastActiveAt; v != nil {
builder.WriteString("last_active_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
......@@ -498,6 +577,9 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("total_recharged=")
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
builder.WriteString(", ")
builder.WriteString("rpm_limit=")
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')
return builder.String()
}
......
......@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// FieldSignupSource holds the string denoting the signup_source field in the database.
FieldSignupSource = "signup_source"
// FieldLastLoginAt holds the string denoting the last_login_at field in the database.
FieldLastLoginAt = "last_login_at"
// FieldLastActiveAt holds the string denoting the last_active_at field in the database.
FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
......@@ -53,6 +59,8 @@ const (
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
FieldTotalRecharged = "total_recharged"
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
......@@ -73,6 +81,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
// EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
EdgeAuthIdentities = "auth_identities"
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
......@@ -145,6 +157,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
// AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
AuthIdentitiesTable = "auth_identities"
// AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
// It exists in this package in order to avoid circular dependency with the "authidentity" package.
AuthIdentitiesInverseTable = "auth_identities"
// AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
AuthIdentitiesColumn = "user_id"
// PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
PendingAuthSessionsTable = "pending_auth_sessions"
// PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
// It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
PendingAuthSessionsInverseTable = "pending_auth_sessions"
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
......@@ -171,11 +197,15 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldSignupSource,
FieldLastLoginAt,
FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
FieldBalanceNotifyExtraEmails,
FieldTotalRecharged,
FieldRpmLimit,
}
var (
......@@ -232,6 +262,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
// DefaultSignupSource holds the default value on creation for the "signup_source" field.
DefaultSignupSource string
// SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
......@@ -240,6 +274,8 @@ var (
DefaultBalanceNotifyExtraEmails string
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
DefaultTotalRecharged float64
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
DefaultRpmLimit int
)
// OrderOption defines the ordering options for the User queries.
......@@ -320,6 +356,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// BySignupSource orders the results by the signup_source field.
func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
}
// ByLastLoginAt orders the results by the last_login_at field.
func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
}
// ByLastActiveAt orders the results by the last_active_at field.
func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
}
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
......@@ -345,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
}
// ByRpmLimit orders the results by the rpm_limit field.
func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......@@ -485,6 +541,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByAuthIdentitiesCount orders the results by auth_identities count.
func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
}
}
// ByAuthIdentities orders the results by auth_identities terms.
func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
}
}
// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
......@@ -568,6 +652,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
func newAuthIdentitiesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
}
func newPendingAuthSessionsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
......
......@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
func SignupSource(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
func LastLoginAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
func LastActiveAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
......@@ -150,6 +165,11 @@ func TotalRecharged(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
func RpmLimit(v int) predicate.User {
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
......@@ -885,6 +905,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
func SignupSourceEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
func SignupSourceNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
}
// SignupSourceIn applies the In predicate on the "signup_source" field.
func SignupSourceIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
}
// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
func SignupSourceNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
}
// SignupSourceGT applies the GT predicate on the "signup_source" field.
func SignupSourceGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldSignupSource, v))
}
// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
func SignupSourceGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldSignupSource, v))
}
// SignupSourceLT applies the LT predicate on the "signup_source" field.
func SignupSourceLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldSignupSource, v))
}
// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
func SignupSourceLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldSignupSource, v))
}
// SignupSourceContains applies the Contains predicate on the "signup_source" field.
func SignupSourceContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldSignupSource, v))
}
// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
func SignupSourceHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
}
// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
func SignupSourceHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
}
// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
func SignupSourceEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
}
// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
func SignupSourceContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
}
// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
func LastLoginAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
func LastLoginAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
}
// LastLoginAtIn applies the In predicate on the "last_login_at" field.
func LastLoginAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
}
// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
func LastLoginAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
}
// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
func LastLoginAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
}
// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
func LastLoginAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
}
// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
func LastLoginAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
}
// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
func LastLoginAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
}
// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
func LastLoginAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
}
// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
func LastLoginAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
}
// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
func LastActiveAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
func LastActiveAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
}
// LastActiveAtIn applies the In predicate on the "last_active_at" field.
func LastActiveAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
}
// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
func LastActiveAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
}
// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
func LastActiveAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
}
// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
func LastActiveAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
}
// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
func LastActiveAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
}
// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
func LastActiveAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
}
// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
func LastActiveAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
}
// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
func LastActiveAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
}
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
......@@ -1115,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
}
// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
func RpmLimitEQ(v int) predicate.User {
return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
}
// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
func RpmLimitNEQ(v int) predicate.User {
return predicate.User(sql.FieldNEQ(FieldRpmLimit, v))
}
// RpmLimitIn applies the In predicate on the "rpm_limit" field.
func RpmLimitIn(vs ...int) predicate.User {
return predicate.User(sql.FieldIn(FieldRpmLimit, vs...))
}
// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
func RpmLimitNotIn(vs ...int) predicate.User {
return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...))
}
// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
func RpmLimitGT(v int) predicate.User {
return predicate.User(sql.FieldGT(FieldRpmLimit, v))
}
// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
func RpmLimitGTE(v int) predicate.User {
return predicate.User(sql.FieldGTE(FieldRpmLimit, v))
}
// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
func RpmLimitLT(v int) predicate.User {
return predicate.User(sql.FieldLT(FieldRpmLimit, v))
}
// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
func RpmLimitLTE(v int) predicate.User {
return predicate.User(sql.FieldLTE(FieldRpmLimit, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
......@@ -1345,6 +1570,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
func HasAuthIdentities() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAuthIdentitiesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
func HasPendingAuthSessions() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPendingAuthSessionsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
......
......@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
......@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
// SetSignupSource sets the "signup_source" field.
func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
_c.mutation.SetSignupSource(v)
return _c
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
if v != nil {
_c.SetSignupSource(*v)
}
return _c
}
// SetLastLoginAt sets the "last_login_at" field.
func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
_c.mutation.SetLastLoginAt(v)
return _c
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastLoginAt(*v)
}
return _c
}
// SetLastActiveAt sets the "last_active_at" field.
func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
_c.mutation.SetLastActiveAt(v)
return _c
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastActiveAt(*v)
}
return _c
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
......@@ -281,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
return _c
}
// SetRpmLimit sets the "rpm_limit" field.
func (_c *UserCreate) SetRpmLimit(v int) *UserCreate {
_c.mutation.SetRpmLimit(v)
return _c
}
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate {
if v != nil {
_c.SetRpmLimit(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
......@@ -431,6 +489,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
_c.mutation.AddAuthIdentityIDs(ids...)
return _c
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
_c.mutation.AddPendingAuthSessionIDs(ids...)
return _c
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
......@@ -510,6 +598,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
if _, ok := _c.mutation.SignupSource(); !ok {
v := user.DefaultSignupSource
_c.mutation.SetSignupSource(v)
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
......@@ -526,6 +618,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotalRecharged
_c.mutation.SetTotalRecharged(v)
}
if _, ok := _c.mutation.RpmLimit(); !ok {
v := user.DefaultRpmLimit
_c.mutation.SetRpmLimit(v)
}
return nil
}
......@@ -589,6 +685,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
if _, ok := _c.mutation.SignupSource(); !ok {
return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
}
if v, ok := _c.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
......@@ -601,6 +705,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotalRecharged(); !ok {
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
}
if _, ok := _c.mutation.RpmLimit(); !ok {
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)}
}
return nil
}
......@@ -684,6 +791,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if value, ok := _c.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
_node.SignupSource = value
}
if value, ok := _c.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
_node.LastLoginAt = &value
}
if value, ok := _c.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
_node.LastActiveAt = &value
}
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
......@@ -704,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
_node.TotalRecharged = value
}
if value, ok := _c.mutation.RpmLimit(); ok {
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
_node.RpmLimit = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -868,6 +991,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
......@@ -1106,6 +1261,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
u.Set(user.FieldSignupSource, v)
return u
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
u.SetExcluded(user.FieldSignupSource)
return u
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastLoginAt, v)
return u
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
u.SetExcluded(user.FieldLastLoginAt)
return u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
u.SetNull(user.FieldLastLoginAt)
return u
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastActiveAt, v)
return u
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
u.SetExcluded(user.FieldLastActiveAt)
return u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
u.SetNull(user.FieldLastActiveAt)
return u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
......@@ -1184,6 +1387,24 @@ func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
return u
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert {
u.Set(user.FieldRpmLimit, v)
return u
}
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
func (u *UserUpsert) UpdateRpmLimit() *UserUpsert {
u.SetExcluded(user.FieldRpmLimit)
return u
}
// AddRpmLimit adds v to the "rpm_limit" field.
func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert {
u.Add(user.FieldRpmLimit, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
......@@ -1446,6 +1667,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
......@@ -1537,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
})
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetRpmLimit(v)
})
}
// AddRpmLimit adds v to the "rpm_limit" field.
func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddRpmLimit(v)
})
}
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateRpmLimit()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
......@@ -1965,6 +2263,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
......@@ -2056,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
})
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetRpmLimit(v)
})
}
// AddRpmLimit adds v to the "rpm_limit" field.
func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddRpmLimit(v)
})
}
// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateRpmLimit()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
......
......@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
......@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
withAuthIdentities *AuthIdentityQuery
withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
......@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
// QueryAuthIdentities chains the current query on the "auth_identities" edge.
func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
......@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
withAuthIdentities: _q.withAuthIdentities.Clone(),
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
......@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAuthIdentities = query
return _q
}
// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPendingAuthSessions = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
......@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [11]bool{
loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
......@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
_q.withAuthIdentities != nil,
_q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
......@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withAuthIdentities; query != nil {
if err := _q.loadAuthIdentities(ctx, query, nodes,
func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
return nil, err
}
}
if query := _q.withPendingAuthSessions; query != nil {
if err := _q.loadPendingAuthSessions(ctx, query, nodes,
func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
func(n *User, e *PendingAuthSession) {
n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
}); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
......@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(authidentity.FieldUserID)
}
query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
}
query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.TargetUserID
if fk == nil {
return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
......
......@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
......@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
......@@ -333,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
return _u
}
// SetRpmLimit sets the "rpm_limit" field.
func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate {
_u.mutation.ResetRpmLimit()
_u.mutation.SetRpmLimit(v)
return _u
}
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate {
if v != nil {
_u.SetRpmLimit(*v)
}
return _u
}
// AddRpmLimit adds value to the "rpm_limit" field.
func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate {
_u.mutation.AddRpmLimit(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -483,6 +560,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
......@@ -698,6 +805,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
......@@ -767,6 +916,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil
}
......@@ -836,6 +990,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
......@@ -860,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RpmLimit(); ok {
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedRpmLimit(); ok {
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -1322,6 +1497,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
......@@ -1548,6 +1813,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
......@@ -1638,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
return _u
}
// SetRpmLimit sets the "rpm_limit" field.
func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne {
_u.mutation.ResetRpmLimit()
_u.mutation.SetRpmLimit(v)
return _u
}
// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne {
if v != nil {
_u.SetRpmLimit(*v)
}
return _u
}
// AddRpmLimit adds value to the "rpm_limit" field.
func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne {
_u.mutation.AddRpmLimit(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
......@@ -1788,6 +2128,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
......@@ -2003,6 +2373,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
......@@ -2085,6 +2497,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil
}
......@@ -2171,6 +2588,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
......@@ -2195,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.RpmLimit(); ok {
_spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedRpmLimit(); ok {
_spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
......@@ -2657,6 +3095,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
......
......@@ -39,10 +39,11 @@ require (
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.48.0
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
golang.org/x/crypto v0.49.0
golang.org/x/image v0.39.0
golang.org/x/net v0.52.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
......@@ -172,10 +173,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
......
......@@ -413,16 +413,18 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
......@@ -432,16 +434,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
......
......@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
......@@ -65,6 +70,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
......@@ -185,26 +191,47 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type WeChatConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
OpenAppID string `mapstructure:"open_app_id"`
OpenAppSecret string `mapstructure:"open_app_secret"`
MPAppID string `mapstructure:"mp_app_id"`
MPAppSecret string `mapstructure:"mp_app_secret"`
MobileAppID string `mapstructure:"mobile_app_id"`
MobileAppSecret string `mapstructure:"mobile_app_secret"`
OpenEnabled bool `mapstructure:"open_enabled"`
MPEnabled bool `mapstructure:"mp_enabled"`
MobileEnabled bool `mapstructure:"mobile_enabled"`
Mode string `mapstructure:"mode"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
}
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
......@@ -213,6 +240,225 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
)
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func normalizeWeChatConnectMode(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "mp":
return "mp"
case "mobile":
return "mobile"
default:
return defaultWeChatConnectMode
}
}
func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
mode = normalizeWeChatConnectMode(mode)
switch mode {
case "open":
if openEnabled {
return "open"
}
case "mp":
if mpEnabled {
return "mp"
}
case "mobile":
if mobileEnabled {
return "mobile"
}
}
switch {
case openEnabled:
return "open"
case mpEnabled:
return "mp"
case mobileEnabled:
return "mobile"
default:
return mode
}
}
func defaultWeChatConnectScopesForMode(mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
return "snsapi_userinfo"
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func normalizeWeChatConnectScopes(raw, mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
switch strings.TrimSpace(raw) {
case "snsapi_base":
return "snsapi_base"
case "snsapi_userinfo":
return "snsapi_userinfo"
default:
return defaultWeChatConnectScopesForMode(mode)
}
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return false
}
_, hasNewEnv := os.LookupEnv(envKey)
return !hasNewEnv
}
func hasExplicitConfigOrEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return true
}
_, ok := os.LookupEnv(envKey)
return ok
}
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
legacyOpenAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
if legacyOpenAppID != "" {
cfg.OpenAppID = legacyOpenAppID
}
}
legacyOpenAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
if legacyOpenAppSecret != "" {
cfg.OpenAppSecret = legacyOpenAppSecret
}
}
legacyMPAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
if legacyMPAppID != "" {
cfg.MPAppID = legacyMPAppID
}
}
legacyMPAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
if legacyMPAppSecret != "" {
cfg.MPAppSecret = legacyMPAppSecret
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
cfg.FrontendRedirectURL = legacyFrontend
}
}
hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
cfg.Enabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
cfg.OpenEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
cfg.MPEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Mode = "mp"
case hasLegacyOpen:
cfg.Mode = "open"
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
case hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("open")
}
}
}
func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
cfg.AppID = strings.TrimSpace(cfg.AppID)
cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
switch cfg.Mode {
case "mp":
cfg.MPEnabled = true
case "mobile":
cfg.MobileEnabled = true
default:
cfg.OpenEnabled = true
}
}
cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
if cfg.FrontendRedirectURL == "" {
cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
}
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
......@@ -1007,6 +1253,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
normalizeWeChatConnectConfig(&cfg.WeChat)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
......@@ -1024,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
......@@ -1202,6 +1452,24 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// WeChat Connect OAuth 登录
viper.SetDefault("wechat_connect.enabled", false)
viper.SetDefault("wechat_connect.app_id", "")
viper.SetDefault("wechat_connect.app_secret", "")
viper.SetDefault("wechat_connect.open_app_id", "")
viper.SetDefault("wechat_connect.open_app_secret", "")
viper.SetDefault("wechat_connect.mp_app_id", "")
viper.SetDefault("wechat_connect.mp_app_secret", "")
viper.SetDefault("wechat_connect.mobile_app_id", "")
viper.SetDefault("wechat_connect.mobile_app_secret", "")
viper.SetDefault("wechat_connect.open_enabled", false)
viper.SetDefault("wechat_connect.mp_enabled", false)
viper.SetDefault("wechat_connect.mobile_enabled", false)
viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
viper.SetDefault("wechat_connect.redirect_url", "")
viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
......@@ -1217,7 +1485,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
......@@ -1407,7 +1675,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
......@@ -1629,9 +1897,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
......@@ -1662,6 +1927,45 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.WeChat.Enabled {
weChat := c.WeChat
normalizeWeChatConnectConfig(&weChat)
if weChat.OpenEnabled {
if strings.TrimSpace(weChat.OpenAppID) == "" {
return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
}
if strings.TrimSpace(weChat.OpenAppSecret) == "" {
return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
}
}
if weChat.MPEnabled {
if strings.TrimSpace(weChat.MPAppID) == "" {
return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
}
if strings.TrimSpace(weChat.MPAppSecret) == "" {
return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
}
}
if weChat.MobileEnabled {
if strings.TrimSpace(weChat.MobileAppID) == "" {
return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
}
if strings.TrimSpace(weChat.MobileAppSecret) == "" {
return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
}
}
if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.redirect_url", v)
}
if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
......@@ -1685,9 +1989,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
......
......@@ -225,6 +225,52 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.WeChat.Enabled)
require.True(t, cfg.WeChat.OpenEnabled)
require.True(t, cfg.WeChat.MPEnabled)
require.False(t, cfg.WeChat.MobileEnabled)
require.Equal(t, "open", cfg.WeChat.Mode)
require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
}
func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.OIDC.UsePKCE)
require.True(t, cfg.OIDC.ValidateIDToken)
require.False(t, cfg.OIDC.UsePKCEExplicit)
require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
cfg, err := Load()
require.NoError(t, err)
require.False(t, cfg.OIDC.UsePKCE)
require.False(t, cfg.OIDC.ValidateIDToken)
require.True(t, cfg.OIDC.UsePKCEExplicit)
require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
......@@ -334,7 +380,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
cfg.LinuxDo.UsePKCE = false
cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
......@@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
......@@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
if err != nil {
t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
}
......@@ -389,6 +432,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "profile email"
cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err == nil {
......@@ -418,6 +462,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.ValidateIDToken = true
cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err != nil {
......@@ -425,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
}
}
func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
cfg.OIDC.TokenURL = "https://issuer.example.com/token"
cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.UsePKCE = false
cfg.OIDC.ValidateIDToken = false
cfg.OIDC.JWKSURL = ""
cfg.OIDC.AllowedSigningAlgs = ""
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)
......@@ -840,6 +914,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
cfg.LinuxDo.UsePKCE = true
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
......@@ -990,6 +1065,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
......@@ -998,6 +1074,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
......
......@@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
......@@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
bindBody := map[string]any{
"provider_type": "wechat",
"provider_key": "wechat-main",
"provider_subject": "union-123",
"metadata": map[string]any{"source": "admin-repair"},
"channel": map[string]any{
"channel": "open",
"channel_app_id": "wx-open",
"channel_subject": "openid-123",
},
}
body, _ := json.Marshal(bindBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
......@@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
}
func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
router, adminSvc := setupAdminRouter()
body, err := json.Marshal(map[string]any{
"provider_type": "oidc",
"provider_key": "https://issuer.example",
"provider_subject": "subject-123",
"issuer": "https://issuer.example",
"metadata": map[string]any{"report_id": 12},
})
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
require.NotNil(t, adminSvc.boundAuthIdentity)
require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
require.Nil(t, adminSvc.boundAuthIdentity.Channel)
require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
......
......@@ -17,6 +17,8 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
boundAuthIdentity *service.AdminBindAuthIdentityInput
boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
......@@ -42,6 +44,14 @@ type stubAdminService struct {
sortOrder string
calls int
}
lastListUsers struct {
page int
pageSize int
filters service.UserListFilters
sortBy string
sortOrder string
calls int
}
lastListProxies struct {
protocol string
status string
......@@ -127,6 +137,12 @@ func newStubAdminService() *stubAdminService {
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
s.lastListUsers.page = page
s.lastListUsers.pageSize = pageSize
s.lastListUsers.filters = filters
s.lastListUsers.sortBy = sortBy
s.lastListUsers.sortOrder = sortOrder
s.lastListUsers.calls++
return s.users, int64(len(s.users)), nil
}
......@@ -167,6 +183,63 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
user, err := s.GetUser(ctx, userID)
if err != nil {
return nil, err
}
return &service.UserRPMStatus{
UserRPMUsed: 0,
UserRPMLimit: user.RPMLimit,
}, nil
}
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
s.boundAuthIdentityFor = userID
copied := input
if input.Metadata != nil {
copied.Metadata = map[string]any{}
for key, value := range input.Metadata {
copied.Metadata[key] = value
}
}
if input.Channel != nil {
channel := *input.Channel
if input.Channel.Metadata != nil {
channel.Metadata = map[string]any{}
for key, value := range input.Channel.Metadata {
channel.Metadata[key] = value
}
}
copied.Channel = &channel
}
s.boundAuthIdentity = &copied
now := time.Now().UTC()
result := &service.AdminBoundAuthIdentity{
UserID: userID,
ProviderType: input.ProviderType,
ProviderKey: input.ProviderKey,
ProviderSubject: input.ProviderSubject,
VerifiedAt: &now,
Issuer: input.Issuer,
Metadata: input.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
if input.Channel != nil {
result.Channel = &service.AdminBoundAuthIdentityChannel{
Channel: input.Channel.Channel,
ChannelAppID: input.Channel.ChannelAppID,
ChannelSubject: input.Channel.ChannelSubject,
Metadata: input.Channel.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
}
return result, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
......@@ -214,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType
......
......@@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = service.BillingModelSourceChannelMapped
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
......
......@@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
BillingModelSource: service.BillingModelSourceChannelMapped,
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
......@@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
},
}
// handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底
// 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理),
// 因此这里构造 fixture 时直接传入归一化后的值。
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
......@@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) {
// handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
resp := channelToResponse(ch)
require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责")
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// monitorMaxPageSize 列表分页上限。
monitorMaxPageSize = 100
// monitorAPIKeyMaskPrefix 脱敏时保留的明文前缀长度。
monitorAPIKeyMaskPrefix = 4
// monitorAPIKeyMaskSuffix 脱敏后追加的占位字符串。
monitorAPIKeyMaskSuffix = "***"
)
// ChannelMonitorHandler 渠道监控管理后台 handler。
type ChannelMonitorHandler struct {
monitorService *service.ChannelMonitorService
}
// NewChannelMonitorHandler 创建 handler。
func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorHandler {
return &ChannelMonitorHandler{monitorService: monitorService}
}
// --- Request / Response ---
type channelMonitorCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
Endpoint string `json:"endpoint" binding:"required,max=500"`
APIKey string `json:"api_key" binding:"required,max=2000"`
PrimaryModel string `json:"primary_model" binding:"required,max=200"`
ExtraModels []string `json:"extra_models"`
GroupName string `json:"group_name" binding:"max=100"`
Enabled *bool `json:"enabled"`
IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"`
TemplateID *int64 `json:"template_id"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
ExtraModels *[]string `json:"extra_models"`
GroupName *string `json:"group_name" binding:"omitempty,max=100"`
Enabled *bool `json:"enabled"`
IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
TemplateID *int64 `json:"template_id"`
ClearTemplate bool `json:"clear_template"` // true 时把 template_id 置空,忽略 TemplateID
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride *map[string]any `json:"body_override"`
}
type channelMonitorResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Endpoint string `json:"endpoint"`
APIKeyMasked string `json:"api_key_masked"`
APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
PrimaryModel string `json:"primary_model"`
ExtraModels []string `json:"extra_models"`
GroupName string `json:"group_name"`
Enabled bool `json:"enabled"`
IntervalSeconds int `json:"interval_seconds"`
LastCheckedAt *string `json:"last_checked_at"`
CreatedBy int64 `json:"created_by"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
PrimaryStatus string `json:"primary_status"`
PrimaryLatencyMs *int `json:"primary_latency_ms"`
Availability7d float64 `json:"availability_7d"`
ExtraModelsStatus []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"`
// 请求自定义快照:前端编辑 / 展示「高级设置」用
TemplateID *int64 `json:"template_id"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorCheckResultResponse struct {
Model string `json:"model"`
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
PingLatencyMs *int `json:"ping_latency_ms"`
Message string `json:"message"`
CheckedAt string `json:"checked_at"`
}
type channelMonitorHistoryItemResponse struct {
ID int64 `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
PingLatencyMs *int `json:"ping_latency_ms"`
Message string `json:"message"`
CheckedAt string `json:"checked_at"`
}
// maskAPIKey 对 API Key 明文做脱敏:前 4 字符 + "***",长度 ≤ 4 时只显示 "***"。
func maskAPIKey(plain string) string {
if len(plain) <= monitorAPIKeyMaskPrefix {
return monitorAPIKeyMaskSuffix
}
return plain[:monitorAPIKeyMaskPrefix] + monitorAPIKeyMaskSuffix
}
func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse {
if m == nil {
return nil
}
extras := m.ExtraModels
if extras == nil {
extras = []string{}
}
headers := m.ExtraHeaders
if headers == nil {
headers = map[string]string{}
}
resp := &channelMonitorResponse{
ID: m.ID,
Name: m.Name,
Provider: m.Provider,
Endpoint: m.Endpoint,
APIKeyMasked: maskAPIKey(m.APIKey),
APIKeyDecryptFailed: m.APIKeyDecryptFailed,
PrimaryModel: m.PrimaryModel,
ExtraModels: extras,
GroupName: m.GroupName,
Enabled: m.Enabled,
IntervalSeconds: m.IntervalSeconds,
CreatedBy: m.CreatedBy,
CreatedAt: m.CreatedAt.UTC().Format(time.RFC3339),
UpdatedAt: m.UpdatedAt.UTC().Format(time.RFC3339),
TemplateID: m.TemplateID,
ExtraHeaders: headers,
BodyOverrideMode: m.BodyOverrideMode,
BodyOverride: m.BodyOverride,
// PrimaryStatus / PrimaryLatencyMs / Availability7d 由 List handler 在批量聚合后填充。
}
if m.LastCheckedAt != nil {
s := m.LastCheckedAt.UTC().Format(time.RFC3339)
resp.LastCheckedAt = &s
}
return resp
}
func checkResultToResponse(r *service.CheckResult) channelMonitorCheckResultResponse {
return channelMonitorCheckResultResponse{
Model: r.Model,
Status: r.Status,
LatencyMs: r.LatencyMs,
PingLatencyMs: r.PingLatencyMs,
Message: r.Message,
CheckedAt: r.CheckedAt.UTC().Format(time.RFC3339),
}
}
func historyEntryToResponse(e *service.ChannelMonitorHistoryEntry) channelMonitorHistoryItemResponse {
return channelMonitorHistoryItemResponse{
ID: e.ID,
Model: e.Model,
Status: e.Status,
LatencyMs: e.LatencyMs,
PingLatencyMs: e.PingLatencyMs,
Message: e.Message,
CheckedAt: e.CheckedAt.UTC().Format(time.RFC3339),
}
}
// ParseChannelMonitorID 提取并校验路径参数 :id(admin 与 user handler 共享)。
// 校验失败时已写入 4xx 响应,调用方只需 return。
func ParseChannelMonitorID(c *gin.Context) (int64, bool) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_MONITOR_ID", "invalid monitor id"))
return 0, false
}
return id, true
}
// parseListEnabled 解析 enabled query 参数:true/false 转为 *bool,空或非法则返回 nil。
func parseListEnabled(raw string) *bool {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "true", "1", "yes":
v := true
return &v
case "false", "0", "no":
v := false
return &v
default:
return nil
}
}
// --- Handlers ---
// List GET /api/v1/admin/channel-monitors
func (h *ChannelMonitorHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
if pageSize > monitorMaxPageSize {
pageSize = monitorMaxPageSize
}
params := service.ChannelMonitorListParams{
Page: page,
PageSize: pageSize,
Provider: strings.TrimSpace(c.Query("provider")),
Enabled: parseListEnabled(c.Query("enabled")),
Search: strings.TrimSpace(c.Query("search")),
}
items, total, err := h.monitorService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
summaries := h.batchSummaryFor(c, items)
out := make([]*channelMonitorResponse, 0, len(items))
for _, m := range items {
out = append(out, buildListItemResponse(m, summaries[m.ID]))
}
response.Paginated(c, out, total, page, pageSize)
}
// batchSummaryFor 批量聚合 latest + 7d 可用率,避免每行 2 次 SQL(消除 N+1)。
func (h *ChannelMonitorHandler) batchSummaryFor(c *gin.Context, items []*service.ChannelMonitor) map[int64]service.MonitorStatusSummary {
ids := make([]int64, 0, len(items))
primaryByID := make(map[int64]string, len(items))
extrasByID := make(map[int64][]string, len(items))
for _, m := range items {
ids = append(ids, m.ID)
primaryByID[m.ID] = m.PrimaryModel
extrasByID[m.ID] = m.ExtraModels
}
return h.monitorService.BatchMonitorStatusSummary(c.Request.Context(), ids, primaryByID, extrasByID)
}
// buildListItemResponse 把 monitor + summary 装成 admin list 的响应行。
func buildListItemResponse(m *service.ChannelMonitor, summary service.MonitorStatusSummary) *channelMonitorResponse {
resp := channelMonitorToResponse(m)
resp.PrimaryStatus = summary.PrimaryStatus
resp.PrimaryLatencyMs = summary.PrimaryLatencyMs
resp.Availability7d = summary.Availability7d
resp.ExtraModelsStatus = make([]dto.ChannelMonitorExtraModelStatus, 0, len(summary.ExtraModels))
for _, e := range summary.ExtraModels {
resp.ExtraModelsStatus = append(resp.ExtraModelsStatus, dto.ChannelMonitorExtraModelStatus{
Model: e.Model,
Status: e.Status,
LatencyMs: e.LatencyMs,
})
}
return resp
}
// Get GET /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Get(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
m, err := h.monitorService.Get(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelMonitorToResponse(m))
}
// Create POST /api/v1/admin/channel-monitors
func (h *ChannelMonitorHandler) Create(c *gin.Context) {
var req channelMonitorCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
subject, _ := middleware2.GetAuthSubjectFromContext(c)
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
Name: req.Name,
Provider: req.Provider,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
ExtraModels: req.ExtraModels,
GroupName: req.GroupName,
Enabled: enabled,
IntervalSeconds: req.IntervalSeconds,
CreatedBy: subject.UserID,
TemplateID: req.TemplateID,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Created(c, channelMonitorToResponse(m))
}
// Update PUT /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Update(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
var req channelMonitorUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
Name: req.Name,
Provider: req.Provider,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
ExtraModels: req.ExtraModels,
GroupName: req.GroupName,
Enabled: req.Enabled,
IntervalSeconds: req.IntervalSeconds,
TemplateID: req.TemplateID,
ClearTemplate: req.ClearTemplate,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelMonitorToResponse(m))
}
// Delete DELETE /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Delete(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
if err := h.monitorService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// Run POST /api/v1/admin/channel-monitors/:id/run
func (h *ChannelMonitorHandler) Run(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
results, err := h.monitorService.RunCheck(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]channelMonitorCheckResultResponse, 0, len(results))
for _, r := range results {
out = append(out, checkResultToResponse(r))
}
response.Success(c, gin.H{"results": out})
}
// History GET /api/v1/admin/channel-monitors/:id/history
func (h *ChannelMonitorHandler) History(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
limit := parseHistoryLimit(c.Query("limit"))
model := strings.TrimSpace(c.Query("model"))
entries, err := h.monitorService.ListHistory(c.Request.Context(), id, model, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]channelMonitorHistoryItemResponse, 0, len(entries))
for _, e := range entries {
out = append(out, historyEntryToResponse(e))
}
response.Success(c, gin.H{"items": out})
}
// parseHistoryLimit 解析 history 接口的 limit query。
// 使用 service 包的统一上下限常量,避免在 handler 重复定义同名魔法值。
func parseHistoryLimit(raw string) int {
if strings.TrimSpace(raw) == "" {
return service.MonitorHistoryDefaultLimit
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
return service.MonitorHistoryDefaultLimit
}
if v > service.MonitorHistoryMaxLimit {
return service.MonitorHistoryMaxLimit
}
return v
}
package admin
import (
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。
type ChannelMonitorRequestTemplateHandler struct {
templateService *service.ChannelMonitorRequestTemplateService
}
// NewChannelMonitorRequestTemplateHandler 创建 handler。
func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler {
return &ChannelMonitorRequestTemplateHandler{templateService: templateService}
}
// --- DTO ---
type channelMonitorTemplateCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
Description string `json:"description" binding:"max=500"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorTemplateUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description" binding:"omitempty,max=500"`
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride *map[string]any `json:"body_override"`
}
type channelMonitorTemplateResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Description string `json:"description"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
BodyOverride map[string]any `json:"body_override"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
AssociatedMonitors int64 `json:"associated_monitors"`
}
func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse {
if t == nil {
return nil
}
headers := t.ExtraHeaders
if headers == nil {
headers = map[string]string{}
}
count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID)
return &channelMonitorTemplateResponse{
ID: t.ID,
Name: t.Name,
Provider: t.Provider,
Description: t.Description,
ExtraHeaders: headers,
BodyOverrideMode: t.BodyOverrideMode,
BodyOverride: t.BodyOverride,
CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339),
AssociatedMonitors: count,
}
}
// parseTemplateID 提取并校验 :id。
func parseTemplateID(c *gin.Context) (int64, bool) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id"))
return 0, false
}
return id, true
}
// --- Handlers ---
// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic
func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
Provider: strings.TrimSpace(c.Query("provider")),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*channelMonitorTemplateResponse, 0, len(items))
for _, t := range items {
out = append(out, h.toResponse(c, t))
}
response.Success(c, gin.H{"items": out})
}
// Get GET /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
t, err := h.templateService.Get(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.toResponse(c, t))
}
// Create POST /api/v1/admin/channel-monitor-templates
func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
var req channelMonitorTemplateCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
Name: req.Name,
Provider: req.Provider,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Created(c, h.toResponse(c, t))
}
// Update PUT /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
var req channelMonitorTemplateUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
Name: req.Name,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.toResponse(c, t))
}
// Delete DELETE /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
if err := h.templateService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
type channelMonitorTemplateApplyRequest struct {
// MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。
// 仅当对应监控当前 template_id == :id 时才会真的被覆盖。
MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"`
}
// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。
func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
var req channelMonitorTemplateApplyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"affected": affected})
}
type associatedMonitorBriefResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
}
// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors
// 列出关联监控(picker 弹窗用)。
func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]associatedMonitorBriefResponse, 0, len(items))
for _, m := range items {
out = append(out, associatedMonitorBriefResponse{
ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
})
}
response.Success(c, gin.H{"items": out})
}
......@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制)
RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
type BatchSetGroupRPMOverridesRequest struct {
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
}
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
// PUT /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRPMOverridesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
}
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
// DELETE /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment