package service import ( "context" "fmt" "log/slog" "math" "strconv" "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // --- Refund Flow --- func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { return err } u, err := s.userRepo.GetByID(ctx, o.UserID) if err != nil { return fmt.Errorf("get user: %w", err) } if u.Balance < o.Amount { return infraerrors.BadRequest("BALANCE_NOT_ENOUGH", "refund amount exceeds balance") } nr := strings.TrimSpace(reason) now := time.Now() by := fmt.Sprintf("%d", uid) c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.UserIDEQ(uid), paymentorder.StatusEQ(OrderStatusCompleted), paymentorder.OrderTypeEQ(payment.OrderTypeBalance)).SetStatus(OrderStatusRefundRequested).SetRefundRequestedAt(now).SetRefundRequestReason(nr).SetRefundRequestedBy(by).SetRefundAmount(o.Amount).Save(ctx) if err != nil { return fmt.Errorf("update: %w", err) } if c == 0 { return infraerrors.Conflict("CONFLICT", "order status changed") } s.writeAuditLog(ctx, oid, "REFUND_REQUESTED", fmt.Sprintf("user:%d", uid), map[string]any{"amount": o.Amount, "reason": nr}) return nil } func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int64) (*dbent.PaymentOrder, error) { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { return nil, infraerrors.NotFound("NOT_FOUND", "order not found") } if o.UserID != uid { return nil, infraerrors.Forbidden("FORBIDDEN", "no permission") } if o.OrderType != payment.OrderTypeBalance { return nil, infraerrors.BadRequest("INVALID_ORDER_TYPE", "only balance orders can request refund") } if o.Status != OrderStatusCompleted { return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } return o, nil } func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float64, reason string, force, deduct bool) (*RefundPlan, *RefundResult, error) { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { return nil, nil, infraerrors.NotFound("NOT_FOUND", "order not found") } ok := []string{OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed} if !psSliceContains(ok, o.Status) { return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } if amt <= 0 { amt = o.Amount } if amt-o.Amount > amountToleranceCNY { return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge") } // Full refund: use actual pay_amount for gateway (includes fees) ga := amt if math.Abs(amt-o.Amount) <= amountToleranceCNY { ga = o.PayAmount } rr := strings.TrimSpace(reason) if rr == "" && o.RefundRequestReason != nil { rr = *o.RefundRequestReason } if rr == "" { rr = fmt.Sprintf("refund order:%d", o.ID) } p := &RefundPlan{OrderID: oid, Order: o, RefundAmount: amt, GatewayAmount: ga, Reason: rr, Force: force, DeductBalance: deduct, DeductionType: payment.DeductionTypeNone} if deduct { if er := s.prepDeduct(ctx, o, p, force); er != nil { return nil, er, nil } } return p, nil, nil } func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult { if o.OrderType == payment.OrderTypeSubscription { p.DeductionType = payment.DeductionTypeSubscription if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil { p.SubDaysToDeduct = *o.SubscriptionDays sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID) if err == nil && sub != nil { p.SubscriptionID = sub.ID } else if !force { return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true} } } return nil } u, err := s.userRepo.GetByID(ctx, o.UserID) if err != nil { if !force { return &RefundResult{Success: false, Warning: "cannot fetch user balance, use force", RequireForce: true} } return nil } p.DeductionType = payment.DeductionTypeBalance p.BalanceToDeduct = math.Min(p.RefundAmount, u.Balance) return nil } func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*RefundResult, error) { c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(p.OrderID), paymentorder.StatusIn(OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed)).SetStatus(OrderStatusRefunding).Save(ctx) if err != nil { return nil, fmt.Errorf("lock: %w", err) } if c == 0 { return nil, infraerrors.Conflict("CONFLICT", "order status changed") } if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 { // Skip balance deduction on retry if previous attempt already deducted // but failed to roll back (REFUND_ROLLBACK_FAILED in audit log). if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil { s.restoreStatus(ctx, p) return nil, fmt.Errorf("deduction: %w", err) } } else { slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID) p.BalanceToDeduct = 0 } } if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) if err != nil { // If deducting would expire the subscription, revoke it entirely slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { s.restoreStatus(ctx, p) return nil, fmt.Errorf("revoke subscription: %w", revokeErr) } } } else { slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID) p.SubDaysToDeduct = 0 } } if err := s.gwRefund(ctx, p); err != nil { return s.handleGwFail(ctx, p, err) } return s.markRefundOk(ctx, p) } func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error { if p.Order.PaymentTradeNo == "" { s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"}) return nil } // Use the exact provider instance that created this order, not a random one // from the registry. Each instance has its own merchant credentials. prov, err := s.getRefundProvider(ctx, p.Order) if err != nil { return fmt.Errorf("get refund provider: %w", err) } _, err = prov.Refund(ctx, payment.RefundRequest{ TradeNo: p.Order.PaymentTradeNo, OrderID: p.Order.OutTradeNo, Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64), Reason: p.Reason, }) return err } // getRefundProvider creates a provider using the order's original instance config. // Delegates to getOrderProvider which handles instance lookup and fallback. func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { return s.getOrderProvider(ctx, o) } func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) { if s.RollbackRefund(ctx, p, gErr) { s.restoreStatus(ctx, p) s.writeAuditLog(ctx, p.OrderID, "REFUND_GATEWAY_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)}) return &RefundResult{Success: false, Warning: "gateway failed: " + psErrMsg(gErr) + ", rolled back"}, nil } now := time.Now() _, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(OrderStatusRefundFailed).SetFailedAt(now).SetFailedReason(psErrMsg(gErr)).Save(ctx) s.writeAuditLog(ctx, p.OrderID, "REFUND_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)}) return nil, infraerrors.InternalServer("REFUND_FAILED", psErrMsg(gErr)) } func (s *PaymentService) markRefundOk(ctx context.Context, p *RefundPlan) (*RefundResult, error) { fs := OrderStatusRefunded if p.RefundAmount < p.Order.Amount { fs = OrderStatusPartiallyRefunded } now := time.Now() _, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(fs).SetRefundAmount(p.RefundAmount).SetRefundReason(p.Reason).SetRefundAt(now).SetForceRefund(p.Force).Save(ctx) if err != nil { return nil, fmt.Errorf("mark refund: %w", err) } s.writeAuditLog(ctx, p.OrderID, "REFUND_SUCCESS", "admin", map[string]any{"refundAmount": p.RefundAmount, "reason": p.Reason, "balanceDeducted": p.BalanceToDeduct, "force": p.Force}) return &RefundResult{Success: true, BalanceDeducted: p.BalanceToDeduct, SubDaysDeducted: p.SubDaysToDeduct}, nil } func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr error) bool { if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 { if err := s.userRepo.UpdateBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil { slog.Error("[CRITICAL] rollback failed", "orderID", p.OrderID, "amount", p.BalanceToDeduct, "error", err) s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "balanceDeducted": p.BalanceToDeduct}) return false } } if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil { slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err) s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct}) return false } } return true } func (s *PaymentService) restoreStatus(ctx context.Context, p *RefundPlan) { rs := OrderStatusCompleted if p.Order.Status == OrderStatusRefundRequested { rs = OrderStatusRefundRequested } _, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(rs).Save(ctx) }