Commit f1297a36 authored by erio's avatar erio
Browse files

feat: add per-provider allow_user_refund control and align wildcard matching

allow_user_refund:
- Add allow_user_refund field to PaymentProviderInstance ent schema
- Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN
- Cascade logic: disabling refund_enabled auto-disables allow_user_refund
- User refund validation: check provider instance allows user refund
- Admin refund validation: check provider instance allows admin refund
- Subscription refund: deduct days on refund, rollback on failure
- New endpoint: GET /payment/orders/refund-eligible-providers
- Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView

Wildcard matching:
- Change findPricingForModel from "longest prefix wins" to "config order
  priority (first match wins)", aligning with channel service behavior
parent e8ee400a
...@@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) { ...@@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{ for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription, c.UserSubscription,
} { } {
n.Use(hooks...) n.Use(hooks...)
...@@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) { ...@@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{ for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription, c.UserSubscription,
} { } {
n.Intercept(interceptors...) n.Intercept(interceptors...)
...@@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription ...@@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type ( type (
hooks struct { hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
PaymentOrder, PaymentProviderInstance, PromoCode, PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook UserAttributeValue, UserSubscription []ent.Hook
} }
inters struct { inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
PaymentOrder, PaymentProviderInstance, PromoCode, PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor UserAttributeValue, UserSubscription []ent.Interceptor
} }
) )
......
...@@ -336,7 +336,6 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro ...@@ -336,7 +336,6 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q) return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q)
} }
// The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier. // The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error) type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error)
......
...@@ -616,6 +616,7 @@ var ( ...@@ -616,6 +616,7 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "refund_enabled", Type: field.TypeBool, Default: false}, {Name: "refund_enabled", Type: field.TypeBool, Default: false},
{Name: "allow_user_refund", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
} }
......
...@@ -15655,6 +15655,7 @@ type PaymentProviderInstanceMutation struct { ...@@ -15655,6 +15655,7 @@ type PaymentProviderInstanceMutation struct {
addsort_order *int addsort_order *int
limits *string limits *string
refund_enabled *bool refund_enabled *bool
allow_user_refund *bool
created_at *time.Time created_at *time.Time
updated_at *time.Time updated_at *time.Time
clearedFields map[string]struct{} clearedFields map[string]struct{}
...@@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { ...@@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
m.refund_enabled = nil m.refund_enabled = nil
} }
   
// SetAllowUserRefund sets the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
m.allow_user_refund = &b
}
// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
v := m.allow_user_refund
if v == nil {
return
}
return *v, true
}
// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
}
return oldValue.AllowUserRefund, nil
}
// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
m.allow_user_refund = nil
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
m.created_at = &t m.created_at = &t
...@@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string { ...@@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *PaymentProviderInstanceMutation) Fields() []string { func (m *PaymentProviderInstanceMutation) Fields() []string {
fields := make([]string, 0, 11) fields := make([]string, 0, 12)
if m.provider_key != nil { if m.provider_key != nil {
fields = append(fields, paymentproviderinstance.FieldProviderKey) fields = append(fields, paymentproviderinstance.FieldProviderKey)
} }
...@@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { ...@@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
if m.refund_enabled != nil { if m.refund_enabled != nil {
fields = append(fields, paymentproviderinstance.FieldRefundEnabled) fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
} }
if m.allow_user_refund != nil {
fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
}
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, paymentproviderinstance.FieldCreatedAt) fields = append(fields, paymentproviderinstance.FieldCreatedAt)
} }
...@@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { ...@@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
return m.Limits() return m.Limits()
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
return m.RefundEnabled() return m.RefundEnabled()
case paymentproviderinstance.FieldAllowUserRefund:
return m.AllowUserRefund()
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
return m.CreatedAt() return m.CreatedAt()
case paymentproviderinstance.FieldUpdatedAt: case paymentproviderinstance.FieldUpdatedAt:
...@@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str ...@@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str
return m.OldLimits(ctx) return m.OldLimits(ctx)
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
return m.OldRefundEnabled(ctx) return m.OldRefundEnabled(ctx)
case paymentproviderinstance.FieldAllowUserRefund:
return m.OldAllowUserRefund(ctx)
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
return m.OldCreatedAt(ctx) return m.OldCreatedAt(ctx)
case paymentproviderinstance.FieldUpdatedAt: case paymentproviderinstance.FieldUpdatedAt:
...@@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) ...@@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value)
} }
m.SetRefundEnabled(v) m.SetRefundEnabled(v)
return nil return nil
case paymentproviderinstance.FieldAllowUserRefund:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowUserRefund(v)
return nil
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
v, ok := value.(time.Time) v, ok := value.(time.Time)
if !ok { if !ok {
...@@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error { ...@@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
m.ResetRefundEnabled() m.ResetRefundEnabled()
return nil return nil
case paymentproviderinstance.FieldAllowUserRefund:
m.ResetAllowUserRefund()
return nil
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
m.ResetCreatedAt() m.ResetCreatedAt()
return nil return nil
...@@ -35,6 +35,8 @@ type PaymentProviderInstance struct { ...@@ -35,6 +35,8 @@ type PaymentProviderInstance struct {
Limits string `json:"limits,omitempty"` Limits string `json:"limits,omitempty"`
// RefundEnabled holds the value of the "refund_enabled" field. // RefundEnabled holds the value of the "refund_enabled" field.
RefundEnabled bool `json:"refund_enabled,omitempty"` RefundEnabled bool `json:"refund_enabled,omitempty"`
// AllowUserRefund holds the value of the "allow_user_refund" field.
AllowUserRefund bool `json:"allow_user_refund,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field. // UpdatedAt holds the value of the "updated_at" field.
...@@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) { ...@@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
...@@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) ...@@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any)
} else if value.Valid { } else if value.Valid {
_m.RefundEnabled = value.Bool _m.RefundEnabled = value.Bool
} }
case paymentproviderinstance.FieldAllowUserRefund:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
} else if value.Valid {
_m.AllowUserRefund = value.Bool
}
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok { if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i]) return fmt.Errorf("unexpected type %T for field created_at", values[i])
...@@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string { ...@@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string {
builder.WriteString("refund_enabled=") builder.WriteString("refund_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("allow_user_refund=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
builder.WriteString(", ")
builder.WriteString("created_at=") builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -31,6 +31,8 @@ const ( ...@@ -31,6 +31,8 @@ const (
FieldLimits = "limits" FieldLimits = "limits"
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database. // FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
FieldRefundEnabled = "refund_enabled" FieldRefundEnabled = "refund_enabled"
// FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
FieldAllowUserRefund = "allow_user_refund"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database. // FieldUpdatedAt holds the string denoting the updated_at field in the database.
...@@ -51,6 +53,7 @@ var Columns = []string{ ...@@ -51,6 +53,7 @@ var Columns = []string{
FieldSortOrder, FieldSortOrder,
FieldLimits, FieldLimits,
FieldRefundEnabled, FieldRefundEnabled,
FieldAllowUserRefund,
FieldCreatedAt, FieldCreatedAt,
FieldUpdatedAt, FieldUpdatedAt,
} }
...@@ -88,6 +91,8 @@ var ( ...@@ -88,6 +91,8 @@ var (
DefaultLimits string DefaultLimits string
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
DefaultRefundEnabled bool DefaultRefundEnabled bool
// DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
DefaultAllowUserRefund bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field. // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
...@@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption { ...@@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
} }
// ByAllowUserRefund orders the results by the allow_user_refund field.
func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field. // ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
......
...@@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance { ...@@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
} }
// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.PaymentProviderInstance { func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
...@@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance { ...@@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
} }
// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
......
...@@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym ...@@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym
return _c return _c
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
_c.mutation.SetAllowUserRefund(v)
return _c
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
if v != nil {
_c.SetAllowUserRefund(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
...@@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() { ...@@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() {
v := paymentproviderinstance.DefaultRefundEnabled v := paymentproviderinstance.DefaultRefundEnabled
_c.mutation.SetRefundEnabled(v) _c.mutation.SetRefundEnabled(v)
} }
if _, ok := _c.mutation.AllowUserRefund(); !ok {
v := paymentproviderinstance.DefaultAllowUserRefund
_c.mutation.SetAllowUserRefund(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
v := paymentproviderinstance.DefaultCreatedAt() v := paymentproviderinstance.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
...@@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error { ...@@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error {
if _, ok := _c.mutation.RefundEnabled(); !ok { if _, ok := _c.mutation.RefundEnabled(); !ok {
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
} }
if _, ok := _c.mutation.AllowUserRefund(); !ok {
return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
} }
...@@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, ...@@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance,
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
_node.RefundEnabled = value _node.RefundEnabled = value
} }
if value, ok := _c.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
_node.AllowUserRefund = value
}
if value, ok := _c.mutation.CreatedAt(); ok { if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value _node.CreatedAt = value
...@@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn ...@@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn
return u return u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
return u
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
return u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldUpdatedAt, v) u.Set(paymentproviderinstance.FieldUpdatedAt, v)
...@@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide ...@@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide
}) })
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) { return u.Update(func(s *PaymentProviderInstanceUpsert) {
...@@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid ...@@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid
}) })
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) { return u.Update(func(s *PaymentProviderInstanceUpsert) {
......
...@@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym ...@@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym
return _u return _u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
...@@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int ...@@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int
if value, ok := _u.mutation.RefundEnabled(); ok { if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
} }
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
} }
...@@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P ...@@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P
return _u return _u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
...@@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node ...@@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node
if value, ok := _u.mutation.RefundEnabled(); ok { if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
} }
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
} }
......
...@@ -33,7 +33,6 @@ type IdempotencyRecord func(*sql.Selector) ...@@ -33,7 +33,6 @@ type IdempotencyRecord func(*sql.Selector)
// PaymentAuditLog is the predicate function for paymentauditlog builders. // PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector) type PaymentAuditLog func(*sql.Selector)
// PaymentOrder is the predicate function for paymentorder builders. // PaymentOrder is the predicate function for paymentorder builders.
type PaymentOrder func(*sql.Selector) type PaymentOrder func(*sql.Selector)
......
...@@ -668,12 +668,16 @@ func init() { ...@@ -668,12 +668,16 @@ func init() {
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
// paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
// paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
......
...@@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field { ...@@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field {
Default(""), Default(""),
field.Bool("refund_enabled"). field.Bool("refund_enabled").
Default(false), Default(false),
field.Bool("allow_user_refund").
Default(false),
field.Time("created_at"). field.Time("created_at").
Immutable(). Immutable().
Default(time.Now). Default(time.Now).
......
...@@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { ...@@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"}) response.Success(c, gin.H{"message": "refund requested"})
} }
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"provider_instance_ids": ids})
}
// VerifyOrderRequest is the request body for verifying a payment order. // VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct { type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"` OutTradeNo string `json:"out_trade_no" binding:"required"`
......
...@@ -37,6 +37,7 @@ func RegisterPaymentRoutes( ...@@ -37,6 +37,7 @@ func RegisterPaymentRoutes(
orders.GET("/:id", paymentHandler.GetOrder) orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund) orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
} }
} }
......
...@@ -2,7 +2,6 @@ package service ...@@ -2,7 +2,6 @@ package service
import ( import (
"context" "context"
"sort"
"strings" "strings"
) )
...@@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int ...@@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int
return false return false
} }
// wildcardMatch 通配符匹配候选项(用于排序)
type wildcardMatch struct {
prefixLen int
pricing *ChannelModelPricing
}
// findPricingForModel 在定价列表中查找匹配的模型定价。 // findPricingForModel 在定价列表中查找匹配的模型定价。
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。 // 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
// 精确匹配优先 // 精确匹配优先
for i := range pricingList { for i := range pricingList {
...@@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower ...@@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
} }
} }
} }
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长 // 通配符匹配:按配置顺序,先匹配先使用
var matches []wildcardMatch
for i := range pricingList { for i := range pricingList {
p := &pricingList[i] p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) { if !isPlatformMatch(platform, p.Platform) {
...@@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower ...@@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
} }
prefix := strings.TrimSuffix(ml, "*") prefix := strings.TrimSuffix(ml, "*")
if strings.HasPrefix(modelLower, prefix) { if strings.HasPrefix(modelLower, prefix) {
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p}) return p
} }
} }
} }
if len(matches) == 0 {
return nil return nil
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].prefixLen > matches[j].prefixLen
})
return matches[0].pricing
} }
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 // isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
......
...@@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { ...@@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) {
wantNil: true, wantNil: true,
}, },
{ {
name: "wildcard matches by longest prefix (most specific wins)", name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{ list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}}, {ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}}, {ID: 11, Models: []string{"claude-opus-*"}},
}, },
platform: "", platform: "",
model: "claude-opus-4", model: "claude-opus-4",
wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" wantID: 10, // config order: "claude-*" is first and matches, so it wins
}, },
{ {
name: "shorter wildcard used when longer does not match", name: "shorter wildcard used when longer does not match",
......
...@@ -30,6 +30,7 @@ type ProviderInstanceResponse struct { ...@@ -30,6 +30,7 @@ type ProviderInstanceResponse struct {
Limits string `json:"limits"` Limits string `json:"limits"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"` PaymentMode string `json:"payment_mode"`
} }
...@@ -47,6 +48,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -47,6 +48,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
} }
resp.Config, err = s.decryptAndMaskConfig(inst.Config) resp.Config, err = s.decryptAndMaskConfig(inst.Config)
...@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C ...@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil { if err != nil {
return nil, err return nil, err
} }
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create(). return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx) Save(ctx)
} }
...@@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in ...@@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
} }
if req.RefundEnabled != nil { if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled) u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is true
if *req.AllowUserRefund {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil && inst.RefundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
} }
if req.PaymentMode != nil { if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode) u.SetPaymentMode(*req.PaymentMode)
...@@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont ...@@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont
instances, err := s.entClient.PaymentProviderInstance.Query(). instances, err := s.entClient.PaymentProviderInstance.Query().
Where( Where(
paymentproviderinstance.RefundEnabledEQ(true), paymentproviderinstance.RefundEnabledEQ(true),
paymentproviderinstance.AllowUserRefundEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx) ).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -114,6 +114,7 @@ type CreateProviderInstanceRequest struct { ...@@ -114,6 +114,7 @@ type CreateProviderInstanceRequest struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
Limits string `json:"limits"` Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
} }
type UpdateProviderInstanceRequest struct { type UpdateProviderInstanceRequest struct {
...@@ -125,6 +126,7 @@ type UpdateProviderInstanceRequest struct { ...@@ -125,6 +126,7 @@ type UpdateProviderInstanceRequest struct {
SortOrder *int `json:"sort_order"` SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"` Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"` RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
} }
type CreatePlanRequest struct { type CreatePlanRequest struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
......
...@@ -17,6 +17,19 @@ import ( ...@@ -17,6 +17,19 @@ import (
// --- Refund Flow --- // --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid) o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil { if err != nil {
...@@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int ...@@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if o.Status != OrderStatusCompleted { if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
} }
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
if !inst.AllowUserRefund {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
}
return o, nil return o, nil
} }
...@@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float ...@@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) { if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
} }
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr)
}
if inst != nil && !inst.RefundEnabled {
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
}
if inst == nil && instErr == nil {
// Legacy order without provider_instance_id — block refund
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) { if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
} }
...@@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float ...@@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult { func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
if o.OrderType == payment.OrderTypeSubscription { if o.OrderType == payment.OrderTypeSubscription {
p.DeductionType = payment.DeductionTypeSubscription p.DeductionType = payment.DeductionTypeSubscription
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
p.SubDaysToDeduct = *o.SubscriptionDays
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
if err == nil && sub != nil {
p.SubscriptionID = sub.ID
} else if !force {
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
}
}
return nil return nil
} }
u, err := s.userRepo.GetByID(ctx, o.UserID) u, err := s.userRepo.GetByID(ctx, o.UserID)
...@@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref ...@@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
p.BalanceToDeduct = 0 p.BalanceToDeduct = 0
} }
} }
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil {
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
}
} else {
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.SubDaysToDeduct = 0
}
}
if err := s.gwRefund(ctx, p); err != nil { if err := s.gwRefund(ctx, p); err != nil {
return s.handleGwFail(ctx, p, err) return s.handleGwFail(ctx, p, err)
} }
...@@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr ...@@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr
return false return false
} }
} }
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
return false
}
}
return true return true
} }
......
ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS allow_user_refund BOOLEAN NOT NULL DEFAULT false;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment