Commit 63d1860d authored by erio's avatar erio
Browse files

feat(payment): add complete payment system with multi-provider support

Add a full payment and subscription system supporting EasyPay (Alipay/WeChat),
Stripe, and direct Alipay/WeChat Pay providers with multi-instance load balancing.
parent 00c08c57
// Package payment provides the core payment provider abstraction,
// registry, load balancing, and shared utilities for the payment subsystem.
package payment
import "context"
// PaymentType represents a supported payment method.
type PaymentType = string
// Supported payment type constants.
const (
TypeAlipay PaymentType = "alipay"
TypeWxpay PaymentType = "wxpay"
TypeAlipayDirect PaymentType = "alipay_direct"
TypeWxpayDirect PaymentType = "wxpay_direct"
TypeStripe PaymentType = "stripe"
TypeCard PaymentType = "card"
TypeLink PaymentType = "link"
TypeEasyPay PaymentType = "easypay"
)
// Order status constants shared across payment and service layers.
const (
OrderStatusPending = "PENDING"
OrderStatusPaid = "PAID"
OrderStatusRecharging = "RECHARGING"
OrderStatusCompleted = "COMPLETED"
OrderStatusExpired = "EXPIRED"
OrderStatusCancelled = "CANCELLED"
OrderStatusFailed = "FAILED"
OrderStatusRefundRequested = "REFUND_REQUESTED"
OrderStatusRefunding = "REFUNDING"
OrderStatusPartiallyRefunded = "PARTIALLY_REFUNDED"
OrderStatusRefunded = "REFUNDED"
OrderStatusRefundFailed = "REFUND_FAILED"
)
// Order types distinguish balance recharges from subscription purchases.
const (
OrderTypeBalance = "balance"
OrderTypeSubscription = "subscription"
)
// Entity statuses shared across users, groups, etc.
const (
EntityStatusActive = "active"
)
// Deduction types for refund flow.
const (
DeductionTypeBalance = "balance"
DeductionTypeSubscription = "subscription"
DeductionTypeNone = "none"
)
// Payment notification status values.
const (
NotificationStatusSuccess = "success"
NotificationStatusPaid = "paid"
)
// Provider-level status constants returned by provider implementations
// to the service layer (lowercase, distinct from OrderStatus uppercase constants).
const (
ProviderStatusPending = "pending"
ProviderStatusPaid = "paid"
ProviderStatusSuccess = "success"
ProviderStatusFailed = "failed"
ProviderStatusRefunded = "refunded"
)
// DefaultLoadBalanceStrategy is the default load-balancing strategy
// used when no strategy is configured.
const DefaultLoadBalanceStrategy = "round-robin"
// ConfigKeyPublishableKey is the config map key for Stripe's publishable key.
const ConfigKeyPublishableKey = "publishableKey"
// GetBasePaymentType extracts the base payment method from a composite key.
// For example, "alipay_direct" -> "alipay".
func GetBasePaymentType(t string) string {
switch {
case t == TypeEasyPay:
return TypeEasyPay
case t == TypeStripe || t == TypeCard || t == TypeLink:
return TypeStripe
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
return TypeAlipay
case len(t) >= len(TypeWxpay) && t[:len(TypeWxpay)] == TypeWxpay:
return TypeWxpay
default:
return t
}
}
// CreatePaymentRequest holds the parameters for creating a new payment.
type CreatePaymentRequest struct {
OrderID string // Internal order ID
Amount string // Pay amount in CNY (formatted to 2 decimal places)
PaymentType string // e.g. "alipay", "wxpay", "stripe"
Subject string // Product description
NotifyURL string // Webhook callback URL
ReturnURL string // Browser redirect URL after payment
ClientIP string // Payer's IP address
IsMobile bool // Whether the request comes from a mobile device
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
}
// CreatePaymentResponse is returned after successfully initiating a payment.
type CreatePaymentResponse struct {
TradeNo string // Third-party transaction ID
PayURL string // H5 payment URL (alipay/wxpay)
QRCode string // QR code content for scanning
ClientSecret string // Stripe PaymentIntent client secret
}
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
TradeNo string
Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY
PaidAt string // RFC3339 timestamp or empty
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
TradeNo string
OrderID string
Amount float64
Status string // "success" or "failed"
RawData string // Raw notification body for audit
}
// RefundRequest contains the parameters for requesting a refund.
type RefundRequest struct {
TradeNo string
OrderID string
Amount string // Refund amount formatted to 2 decimal places
Reason string
}
// RefundResponse is returned after a refund request.
type RefundResponse struct {
RefundID string
Status string // "success", "pending", "failed"
}
// InstanceSelection holds the selected provider instance and its decrypted config.
type InstanceSelection struct {
InstanceID string
Config map[string]string
SupportedTypes string // Comma-separated list of supported payment types from the instance
PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
}
// Provider defines the interface that all payment providers must implement.
type Provider interface {
// Name returns a human-readable name for this provider.
Name() string
// ProviderKey returns the unique key identifying this provider type (e.g. "easypay").
ProviderKey() string
// SupportedTypes returns the list of payment types this provider handles.
SupportedTypes() []PaymentType
// CreatePayment initiates a payment and returns the upstream response.
CreatePayment(ctx context.Context, req CreatePaymentRequest) (*CreatePaymentResponse, error)
// QueryOrder queries the payment status of the given trade number.
QueryOrder(ctx context.Context, tradeNo string) (*QueryOrderResponse, error)
// VerifyNotification parses and verifies a webhook callback.
// Returns nil for unrecognized or irrelevant events (caller should return 200).
VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*PaymentNotification, error)
// Refund requests a refund from the upstream provider.
Refund(ctx context.Context, req RefundRequest) (*RefundResponse, error)
}
// CancelableProvider extends Provider with the ability to cancel pending payments.
type CancelableProvider interface {
Provider
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
package payment
import (
"encoding/hex"
"fmt"
"log/slog"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
)
// EncryptionKey is a named type for the payment encryption key (AES-256, 32 bytes).
// Using a named type avoids Wire ambiguity with other []byte parameters.
type EncryptionKey []byte
// ProvideEncryptionKey derives the payment encryption key from the TOTP encryption key in config.
// When the key is empty, nil is returned (payment features that need encryption will be disabled).
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
if cfg.Totp.EncryptionKey == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("payment encryption key must be 32 bytes, got %d", len(key))
}
return EncryptionKey(key), nil
}
// ProvideRegistry creates an empty payment provider registry.
// Providers are registered at runtime after application startup.
func ProvideRegistry() *Registry {
return NewRegistry()
}
// ProvideDefaultLoadBalancer creates a DefaultLoadBalancer backed by the ent client.
func ProvideDefaultLoadBalancer(client *dbent.Client, key EncryptionKey) *DefaultLoadBalancer {
return NewDefaultLoadBalancer(client, []byte(key))
}
// ProviderSet is the Wire provider set for the payment package.
var ProviderSet = wire.NewSet(
ProvideEncryptionKey,
ProvideRegistry,
ProvideDefaultLoadBalancer,
wire.Bind(new(LoadBalancer), new(*DefaultLoadBalancer)),
)
......@@ -583,6 +583,24 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [],
"custom_endpoints": []
}
......@@ -696,7 +714,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
......
......@@ -111,4 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
}
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterPaymentRoutes registers all payment-related routes:
// user-facing endpoints, webhook endpoints, and admin endpoints.
func RegisterPaymentRoutes(
v1 *gin.RouterGroup,
paymentHandler *handler.PaymentHandler,
webhookHandler *handler.PaymentWebhookHandler,
adminPaymentHandler *admin.PaymentHandler,
jwtAuth middleware.JWTAuthMiddleware,
adminAuth middleware.AdminAuthMiddleware,
settingService *service.SettingService,
) {
// --- User-facing payment endpoints (authenticated) ---
authenticated := v1.Group("/payment")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
authenticated.GET("/plans", paymentHandler.GetPlans)
authenticated.GET("/channels", paymentHandler.GetChannels)
authenticated.GET("/limits", paymentHandler.GetLimits)
orders := authenticated.Group("/orders")
{
orders.POST("", paymentHandler.CreateOrder)
orders.POST("/verify", paymentHandler.VerifyOrder)
orders.GET("/my", paymentHandler.GetMyOrders)
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
}
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
}
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
// EasyPay sends GET callbacks with query params
webhook.GET("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
webhook.POST("/stripe", webhookHandler.StripeWebhook)
}
// --- Admin payment endpoints (admin auth) ---
adminGroup := v1.Group("/admin/payment")
adminGroup.Use(gin.HandlerFunc(adminAuth))
{
// Dashboard
adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
// Config
adminGroup.GET("/config", adminPaymentHandler.GetConfig)
adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
// Orders
adminOrders := adminGroup.Group("/orders")
{
adminOrders.GET("", adminPaymentHandler.ListOrders)
adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
}
// Subscription Plans
plans := adminGroup.Group("/plans")
{
plans.GET("", adminPaymentHandler.ListPlans)
plans.POST("", adminPaymentHandler.CreatePlan)
plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
}
// Provider Instances
providers := adminGroup.Group("/providers")
{
providers.GET("", adminPaymentHandler.ListProviders)
providers.POST("", adminPaymentHandler.CreateProvider)
providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
}
}
}
package service
import (
"context"
"encoding/json"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// GetAvailableMethodLimits collects all payment types from enabled provider
// instances and returns limits for each, plus the global widest range.
// Stripe sub-types (card, link) are aggregated under "stripe".
func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*MethodLimitsResponse, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
if err != nil {
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
for pt, insts := range typeInstances {
ml := pcAggregateMethodLimits(pt, insts)
resp.Methods[ml.PaymentType] = ml
}
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
return resp, nil
}
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
if err != nil {
return nil, fmt.Errorf("query provider instances: %w", err)
}
result := make([]MethodLimits, 0, len(types))
for _, pt := range types {
var matching []*dbent.PaymentProviderInstance
for _, inst := range instances {
if payment.InstanceSupportsType(inst.SupportedTypes, pt) {
matching = append(matching, inst)
}
}
result = append(result, pcAggregateMethodLimits(pt, matching))
}
return result, nil
}
// pcGroupByPaymentType groups instances by user-facing payment type.
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
// because the user sees a single "Stripe" button, not individual sub-methods.
// Uses a seen set to avoid counting one instance twice.
func pcGroupByPaymentType(instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
typeInstances := make(map[string][]*dbent.PaymentProviderInstance)
seen := make(map[string]map[int64]bool)
add := func(key string, inst *dbent.PaymentProviderInstance) {
if seen[key] == nil {
seen[key] = make(map[int64]bool)
}
if !seen[key][int64(inst.ID)] {
seen[key][int64(inst.ID)] = true
typeInstances[key] = append(typeInstances[key], inst)
}
}
for _, inst := range instances {
// Stripe provider: all sub-types → single "stripe" group
if inst.ProviderKey == payment.TypeStripe {
add(payment.TypeStripe, inst)
continue
}
for _, t := range splitTypes(inst.SupportedTypes) {
add(t, inst)
}
}
return typeInstances
}
// pcInstanceTypeLimits extracts per-type limits from a provider instance.
// Returns (limits, true) if configured; (zero, false) if unlimited.
// For Stripe instances, limits are stored under "stripe" key regardless of sub-types.
func pcInstanceTypeLimits(inst *dbent.PaymentProviderInstance, pt string) (payment.ChannelLimits, bool) {
if inst.Limits == "" {
return payment.ChannelLimits{}, false
}
var limits payment.InstanceLimits
if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
return payment.ChannelLimits{}, false
}
cl, ok := limits[pt]
return cl, ok
}
// unionFloat merges a single limit value into the aggregate using UNION semantics.
// - For "min" fields (wantMin=true): keeps the lowest non-zero value
// - For "max"/"cap" fields (wantMin=false): keeps the highest non-zero value
// - If any value is 0 (unlimited), the result is unlimited.
//
// Returns (aggregated value, still limited).
func unionFloat(agg float64, limited bool, val float64, wantMin bool) (float64, bool) {
if val == 0 {
return agg, false
}
if !limited {
return agg, false
}
if agg == 0 {
return val, true
}
if wantMin && val < agg {
return val, true
}
if !wantMin && val > agg {
return val, true
}
return agg, true
}
// pcAggregateMethodLimits computes the UNION (least restrictive) of limits
// across all provider instances for a given payment type.
//
// Since the load balancer can route an order to any available instance,
// the user should see the widest possible range:
// - SingleMin: lowest floor across instances; 0 if any is unlimited
// - SingleMax: highest ceiling across instances; 0 if any is unlimited
// - DailyLimit: highest cap across instances; 0 if any is unlimited
func pcAggregateMethodLimits(pt string, instances []*dbent.PaymentProviderInstance) MethodLimits {
ml := MethodLimits{PaymentType: pt}
minLimited, maxLimited, dailyLimited := true, true, true
for _, inst := range instances {
cl, hasLimits := pcInstanceTypeLimits(inst, pt)
if !hasLimits {
return MethodLimits{PaymentType: pt} // any unlimited instance → all zeros
}
ml.SingleMin, minLimited = unionFloat(ml.SingleMin, minLimited, cl.SingleMin, true)
ml.SingleMax, maxLimited = unionFloat(ml.SingleMax, maxLimited, cl.SingleMax, false)
ml.DailyLimit, dailyLimited = unionFloat(ml.DailyLimit, dailyLimited, cl.DailyLimit, false)
}
if !minLimited {
ml.SingleMin = 0
}
if !maxLimited {
ml.SingleMax = 0
}
if !dailyLimited {
ml.DailyLimit = 0
}
return ml
}
// pcComputeGlobalRange computes the widest [min, max] across all methods.
// Uses the same union logic: lowest min, highest max, 0 if any is unlimited.
func pcComputeGlobalRange(methods map[string]MethodLimits) (globalMin, globalMax float64) {
minLimited, maxLimited := true, true
for _, ml := range methods {
globalMin, minLimited = unionFloat(globalMin, minLimited, ml.SingleMin, true)
globalMax, maxLimited = unionFloat(globalMax, maxLimited, ml.SingleMax, false)
}
if !minLimited {
globalMin = 0
}
if !maxLimited {
globalMax = 0
}
return globalMin, globalMax
}
package service
import (
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestUnionFloat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
agg float64
limited bool
val float64
wantMin bool
wantAgg float64
wantLimited bool
}{
{"first non-zero value", 0, true, 5, true, 5, true},
{"lower min replaces", 10, true, 3, true, 3, true},
{"higher min does not replace", 3, true, 10, true, 3, true},
{"higher max replaces", 10, true, 20, false, 20, true},
{"lower max does not replace", 20, true, 10, false, 20, true},
{"zero value makes unlimited", 5, true, 0, true, 5, false},
{"already unlimited stays unlimited", 5, false, 10, true, 5, false},
{"zero on first call", 0, true, 0, true, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotAgg, gotLimited := unionFloat(tt.agg, tt.limited, tt.val, tt.wantMin)
if gotAgg != tt.wantAgg || gotLimited != tt.wantLimited {
t.Fatalf("unionFloat(%v, %v, %v, %v) = (%v, %v), want (%v, %v)",
tt.agg, tt.limited, tt.val, tt.wantMin,
gotAgg, gotLimited, tt.wantAgg, tt.wantLimited)
}
})
}
}
func makeInstance(id int64, providerKey, supportedTypes, limits string) *dbent.PaymentProviderInstance {
return &dbent.PaymentProviderInstance{
ID: id,
ProviderKey: providerKey,
SupportedTypes: supportedTypes,
Limits: limits,
Enabled: true,
}
}
func TestPcAggregateMethodLimits(t *testing.T) {
t.Parallel()
t.Run("single instance with limits", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":2,"singleMax":14},"wxpay":{"singleMin":1,"singleMax":12}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 2 || ml.SingleMax != 14 {
t.Fatalf("alipay limits = min:%v max:%v, want min:2 max:14", ml.SingleMin, ml.SingleMax)
}
})
t.Run("two instances union takes widest range", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":5,"singleMax":100}}`)
inst2 := makeInstance(2, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":2,"singleMax":200}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 2 {
t.Fatalf("SingleMin = %v, want 2 (lowest floor)", ml.SingleMin)
}
if ml.SingleMax != 200 {
t.Fatalf("SingleMax = %v, want 200 (highest ceiling)", ml.SingleMax)
}
})
t.Run("one instance unlimited makes aggregate unlimited", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "wxpay",
`{"wxpay":{"singleMin":3,"singleMax":10}}`)
inst2 := makeInstance(2, "easypay", "wxpay", "") // no limits = unlimited
ml := pcAggregateMethodLimits("wxpay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("limits = min:%v max:%v, want min:0 max:0 (unlimited)", ml.SingleMin, ml.SingleMax)
}
})
t.Run("one field unlimited others limited", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":5,"singleMax":100}}`)
inst2 := makeInstance(2, "easypay", "alipay",
`{"alipay":{"singleMin":3,"singleMax":0}}`) // singleMax=0 = unlimited
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 3 {
t.Fatalf("SingleMin = %v, want 3 (lowest floor)", ml.SingleMin)
}
if ml.SingleMax != 0 {
t.Fatalf("SingleMax = %v, want 0 (unlimited)", ml.SingleMax)
}
})
t.Run("empty instances returns zeros", func(t *testing.T) {
t.Parallel()
ml := pcAggregateMethodLimits("alipay", nil)
if ml.SingleMin != 0 || ml.SingleMax != 0 || ml.DailyLimit != 0 {
t.Fatalf("empty instances should return all zeros, got %+v", ml)
}
})
t.Run("invalid JSON treated as unlimited", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", `{invalid json}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("invalid JSON should be treated as unlimited, got %+v", ml)
}
})
t.Run("type not in limits JSON treated as unlimited", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay,wxpay",
`{"wxpay":{"singleMin":1,"singleMax":10}}`) // only wxpay, no alipay
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("missing type should be treated as unlimited, got %+v", ml)
}
})
t.Run("daily limit aggregation", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":1,"singleMax":100,"dailyLimit":500}}`)
inst2 := makeInstance(2, "easypay", "alipay",
`{"alipay":{"singleMin":2,"singleMax":200,"dailyLimit":1000}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.DailyLimit != 1000 {
t.Fatalf("DailyLimit = %v, want 1000 (highest cap)", ml.DailyLimit)
}
})
}
func TestPcGroupByPaymentType(t *testing.T) {
t.Parallel()
t.Run("stripe instance maps all types to stripe group", func(t *testing.T) {
t.Parallel()
stripe := makeInstance(1, payment.TypeStripe, "card,alipay,link,wxpay", "")
easypay := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe, easypay})
// Stripe instance should only be in "stripe" group
if len(groups[payment.TypeStripe]) != 1 || groups[payment.TypeStripe][0].ID != 1 {
t.Fatalf("stripe group should contain only stripe instance, got %v", groups[payment.TypeStripe])
}
// alipay group should only contain easypay, NOT stripe
if len(groups[payment.TypeAlipay]) != 1 || groups[payment.TypeAlipay][0].ID != 2 {
t.Fatalf("alipay group should contain only easypay instance, got %v", groups[payment.TypeAlipay])
}
// wxpay group should only contain easypay, NOT stripe
if len(groups[payment.TypeWxpay]) != 1 || groups[payment.TypeWxpay][0].ID != 2 {
t.Fatalf("wxpay group should contain only easypay instance, got %v", groups[payment.TypeWxpay])
}
})
t.Run("multiple easypay instances in same groups", func(t *testing.T) {
t.Parallel()
ep1 := makeInstance(1, payment.TypeEasyPay, "alipay,wxpay", "")
ep2 := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{ep1, ep2})
if len(groups[payment.TypeAlipay]) != 2 {
t.Fatalf("alipay group should have 2 instances, got %d", len(groups[payment.TypeAlipay]))
}
if len(groups[payment.TypeWxpay]) != 2 {
t.Fatalf("wxpay group should have 2 instances, got %d", len(groups[payment.TypeWxpay]))
}
})
t.Run("stripe with no supported types still in stripe group", func(t *testing.T) {
t.Parallel()
stripe := makeInstance(1, payment.TypeStripe, "", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe})
if len(groups[payment.TypeStripe]) != 1 {
t.Fatalf("stripe with empty types should still be in stripe group, got %v", groups)
}
})
}
func TestPcComputeGlobalRange(t *testing.T) {
t.Parallel()
t.Run("all methods have limits", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 2, SingleMax: 14},
"wxpay": {SingleMin: 1, SingleMax: 12},
"stripe": {SingleMin: 5, SingleMax: 100},
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 1 {
t.Fatalf("global min = %v, want 1 (lowest floor)", gMin)
}
if gMax != 100 {
t.Fatalf("global max = %v, want 100 (highest ceiling)", gMax)
}
})
t.Run("one method unlimited makes global unlimited", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 2, SingleMax: 14},
"stripe": {SingleMin: 0, SingleMax: 0}, // unlimited
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 0 {
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
}
if gMax != 0 {
t.Fatalf("global max = %v, want 0 (unlimited)", gMax)
}
})
t.Run("empty methods returns zeros", func(t *testing.T) {
t.Parallel()
gMin, gMax := pcComputeGlobalRange(map[string]MethodLimits{})
if gMin != 0 || gMax != 0 {
t.Fatalf("empty methods should return (0, 0), got (%v, %v)", gMin, gMax)
}
})
t.Run("only min unlimited", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 0, SingleMax: 100},
"wxpay": {SingleMin: 5, SingleMax: 50},
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 0 {
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
}
if gMax != 100 {
t.Fatalf("global max = %v, want 100", gMax)
}
})
}
func TestPcInstanceTypeLimits(t *testing.T) {
t.Parallel()
t.Run("empty limits string returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", "")
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for empty limits")
}
})
t.Run("type found returns correct values", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":2,"singleMax":14,"dailyLimit":500}}`)
cl, ok := pcInstanceTypeLimits(inst, "alipay")
if !ok {
t.Fatal("expected ok=true")
}
if cl.SingleMin != 2 || cl.SingleMax != 14 || cl.DailyLimit != 500 {
t.Fatalf("limits = %+v, want min:2 max:14 daily:500", cl)
}
})
t.Run("type not found returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay",
`{"wxpay":{"singleMin":1}}`)
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for missing type")
}
})
t.Run("invalid JSON returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", `{bad json}`)
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for invalid JSON")
}
})
}
package service
import (
"context"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display.
type PlanGroupInfo struct {
Platform string `json:"platform"`
Name string `json:"name"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
ModelScopes []string `json:"supported_model_scopes"`
}
// GetGroupPlatformMap returns a map of group_id → platform for the given plans.
func (s *PaymentConfigService) GetGroupPlatformMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]string {
info := s.GetGroupInfoMap(ctx, plans)
m := make(map[int64]string, len(info))
for id, gi := range info {
m[id] = gi.Platform
}
return m
}
// GetGroupInfoMap returns a map of group_id → PlanGroupInfo for the given plans.
func (s *PaymentConfigService) GetGroupInfoMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]PlanGroupInfo {
ids := make([]int64, 0, len(plans))
seen := make(map[int64]bool)
for _, p := range plans {
if !seen[p.GroupID] {
seen[p.GroupID] = true
ids = append(ids, p.GroupID)
}
}
if len(ids) == 0 {
return nil
}
groups, err := s.entClient.Group.Query().Where(group.IDIn(ids...)).All(ctx)
if err != nil {
return nil
}
m := make(map[int64]PlanGroupInfo, len(groups))
for _, g := range groups {
m[int64(g.ID)] = PlanGroupInfo{
Platform: g.Platform,
Name: g.Name,
RateMultiplier: g.RateMultiplier,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ModelScopes: g.SupportedModelScopes,
}
}
return m
}
func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx)
}
func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx)
}
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
b := s.entClient.SubscriptionPlan.Create().
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
SetFeatures(req.Features).SetProductName(req.ProductName).
SetForSale(req.ForSale).SetSortOrder(req.SortOrder)
if req.OriginalPrice != nil {
b.SetOriginalPrice(*req.OriginalPrice)
}
return b.Save(ctx)
}
// UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
u := s.entClient.SubscriptionPlan.UpdateOneID(id)
if req.GroupID != nil {
u.SetGroupID(*req.GroupID)
}
if req.Name != nil {
u.SetName(*req.Name)
}
if req.Description != nil {
u.SetDescription(*req.Description)
}
if req.Price != nil {
u.SetPrice(*req.Price)
}
if req.OriginalPrice != nil {
u.SetOriginalPrice(*req.OriginalPrice)
}
if req.ValidityDays != nil {
u.SetValidityDays(*req.ValidityDays)
}
if req.ValidityUnit != nil {
u.SetValidityUnit(*req.ValidityUnit)
}
if req.Features != nil {
u.SetFeatures(*req.Features)
}
if req.ProductName != nil {
u.SetProductName(*req.ProductName)
}
if req.ForSale != nil {
u.SetForSale(*req.ForSale)
}
if req.SortOrder != nil {
u.SetSortOrder(*req.SortOrder)
}
return u.Save(ctx)
}
func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error {
count, err := s.countPendingOrdersByPlan(ctx, id)
if err != nil {
return fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return infraerrors.Conflict("PENDING_ORDERS",
fmt.Sprintf("this plan has %d in-progress orders and cannot be deleted — wait for orders to complete first", count))
}
return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx)
}
// GetPlan returns a subscription plan by ID.
func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) {
plan, err := s.entClient.SubscriptionPlan.Get(ctx, id)
if err != nil {
return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found")
}
return plan, nil
}
package service
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Provider Instance CRUD ---
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx)
}
// ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct {
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Context) ([]ProviderInstanceResponse, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Order(paymentproviderinstance.BySortOrder()).All(ctx)
if err != nil {
return nil, err
}
result := make([]ProviderInstanceResponse, 0, len(instances))
for _, inst := range instances {
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder,
PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
result = append(result, resp)
}
return result, nil
}
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
return s.decryptConfig(encrypted)
}
// pendingOrderStatuses are order statuses considered "in progress".
var pendingOrderStatuses = []string{
payment.OrderStatusPending,
payment.OrderStatusPaid,
payment.OrderStatusRecharging,
}
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
func isSensitiveConfigField(fieldName string) bool {
lower := strings.ToLower(fieldName)
for _, p := range sensitiveConfigPatterns {
if strings.Contains(lower, p) {
return true
}
}
return false
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
paymentorder.ProviderInstanceIDEQ(strconv.FormatInt(providerInstanceID, 10)),
paymentorder.StatusIn(pendingOrderStatuses...),
).Count(ctx)
}
func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, planID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
paymentorder.PlanIDEQ(planID),
paymentorder.StatusIn(pendingOrderStatuses...),
).Count(ctx)
}
var validProviderKeys = map[string]bool{
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
}
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
typesStr := joinTypes(req.SupportedTypes)
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
return nil, err
}
enc, err := s.encryptConfig(req.Config)
if err != nil {
return nil, err
}
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
Save(ctx)
}
func validateProviderRequest(providerKey, name, supportedTypes string) error {
if strings.TrimSpace(name) == "" {
return infraerrors.BadRequest("VALIDATION_ERROR", "provider name is required")
}
if !validProviderKeys[providerKey] {
return infraerrors.BadRequest("VALIDATION_ERROR", fmt.Sprintf("invalid provider key: %s", providerKey))
}
// supported_types can be empty (provider accepts no payment types until configured)
return nil
}
// UpdateProviderInstance updates a provider instance by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
if req.Config != nil {
hasSensitive := false
for k := range req.Config {
if isSensitiveConfigField(k) && req.Config[k] != "" {
hasSensitive = true
break
}
}
if hasSensitive {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
}
if req.Enabled != nil && !*req.Enabled {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
if req.Name != nil {
u.SetName(*req.Name)
}
if req.Config != nil {
merged, err := s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
enc, err := s.encryptConfig(merged)
if err != nil {
return nil, err
}
u.SetConfig(enc)
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
// Load current instance to compare types
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
oldTypes := strings.Split(inst.SupportedTypes, ",")
newTypes := req.SupportedTypes
for _, ot := range oldTypes {
ot = strings.TrimSpace(ot)
if ot == "" {
continue
}
found := false
for _, nt := range newTypes {
if strings.TrimSpace(nt) == ot {
found = true
break
}
}
if !found {
return nil, infraerrors.Conflict("PENDING_ORDERS", "cannot remove payment types while instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
}
u.SetSupportedTypes(joinTypes(req.SupportedTypes))
}
if req.Enabled != nil {
u.SetEnabled(*req.Enabled)
}
if req.SortOrder != nil {
u.SetSortOrder(*req.SortOrder)
}
if req.Limits != nil {
u.SetLimits(*req.Limits)
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
}
return u.Save(ctx)
}
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load existing provider: %w", err)
}
existing, err := s.decryptConfig(inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
return newConfig, nil
}
for k, v := range newConfig {
existing[k] = v
}
return existing, nil
}
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
if encrypted == "" {
return nil, nil
}
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
if err != nil {
return nil, fmt.Errorf("decrypt config: %w", err)
}
var raw map[string]string
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
}
return raw, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return infraerrors.Conflict("PENDING_ORDERS",
fmt.Sprintf("this instance has %d in-progress orders and cannot be deleted — wait for orders to complete or disable the instance first", count))
}
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
enc, err := payment.Encrypt(string(data), s.encryptionKey)
if err != nil {
return "", fmt.Errorf("encrypt config: %w", err)
}
return enc, nil
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateProviderRequest(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
providerName string
supportedTypes string
wantErr bool
errContains string
}{
{
name: "valid easypay with types",
providerKey: "easypay",
providerName: "MyProvider",
supportedTypes: "alipay,wxpay",
wantErr: false,
},
{
name: "valid stripe with empty types",
providerKey: "stripe",
providerName: "Stripe Provider",
supportedTypes: "",
wantErr: false,
},
{
name: "valid alipay provider",
providerKey: "alipay",
providerName: "Alipay Direct",
supportedTypes: "alipay",
wantErr: false,
},
{
name: "valid wxpay provider",
providerKey: "wxpay",
providerName: "WeChat Pay",
supportedTypes: "wxpay",
wantErr: false,
},
{
name: "invalid provider key",
providerKey: "invalid",
providerName: "Name",
supportedTypes: "alipay",
wantErr: true,
errContains: "invalid provider key",
},
{
name: "empty name",
providerKey: "easypay",
providerName: "",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
{
name: "whitespace-only name",
providerKey: "easypay",
providerName: " ",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
{
name: "tab-only name",
providerKey: "easypay",
providerName: "\t",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := validateProviderRequest(tc.providerKey, tc.providerName, tc.supportedTypes)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
require.NoError(t, err)
}
})
}
}
func TestIsSensitiveConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
field string
wantSen bool
}{
// Sensitive fields (contain key/secret/private/password/pkey patterns)
{"secretKey", true},
{"apiSecret", true},
{"pkey", true},
{"privateKey", true},
{"apiPassword", true},
{"appKey", true},
{"SECRET_TOKEN", true},
{"PrivateData", true},
{"PASSWORD", true},
{"mySecretValue", true},
// Non-sensitive fields
{"appId", false},
{"mchId", false},
{"apiBase", false},
{"endpoint", false},
{"merchantNo", false},
{"paymentMode", false},
{"notifyUrl", false},
}
for _, tc := range tests {
t.Run(tc.field, func(t *testing.T) {
t.Parallel()
got := isSensitiveConfigField(tc.field)
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
})
}
}
func TestJoinTypes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []string
want string
}{
{
name: "multiple types",
input: []string{"alipay", "wxpay"},
want: "alipay,wxpay",
},
{
name: "single type",
input: []string{"stripe"},
want: "stripe",
},
{
name: "empty slice",
input: []string{},
want: "",
},
{
name: "nil slice",
input: nil,
want: "",
},
{
name: "three types",
input: []string{"alipay", "wxpay", "stripe"},
want: "alipay,wxpay,stripe",
},
{
name: "types with spaces are not trimmed",
input: []string{" alipay ", " wxpay "},
want: " alipay , wxpay ",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := joinTypes(tc.input)
assert.Equal(t, tc.want, got)
})
}
}
package service
import (
"context"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
const (
SettingPaymentEnabled = "payment_enabled"
SettingMinRechargeAmount = "MIN_RECHARGE_AMOUNT"
SettingMaxRechargeAmount = "MAX_RECHARGE_AMOUNT"
SettingDailyRechargeLimit = "DAILY_RECHARGE_LIMIT"
SettingOrderTimeoutMinutes = "ORDER_TIMEOUT_MINUTES"
SettingMaxPendingOrders = "MAX_PENDING_ORDERS"
SettingEnabledPaymentTypes = "ENABLED_PAYMENT_TYPES"
SettingLoadBalanceStrategy = "LOAD_BALANCE_STRATEGY"
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
SettingHelpText = "PAYMENT_HELP_TEXT"
SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED"
SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX"
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
SettingCancelWindowUnit = "CANCEL_RATE_LIMIT_UNIT"
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
)
// Default values for payment configuration settings.
const (
defaultOrderTimeoutMin = 30
defaultMaxPendingOrders = 3
)
// PaymentConfig holds the payment system configuration.
type PaymentConfig struct {
Enabled bool `json:"enabled"`
MinAmount float64 `json:"min_amount"`
MaxAmount float64 `json:"max_amount"`
DailyLimit float64 `json:"daily_limit"`
OrderTimeoutMin int `json:"order_timeout_minutes"`
MaxPendingOrders int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled bool `json:"balance_disabled"`
LoadBalanceStrategy string `json:"load_balance_strategy"`
ProductNamePrefix string `json:"product_name_prefix"`
ProductNameSuffix string `json:"product_name_suffix"`
HelpImageURL string `json:"help_image_url"`
HelpText string `json:"help_text"`
StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
// Cancel rate limit settings
CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
CancelRateLimitMax int `json:"cancel_rate_limit_max"`
CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
}
// UpdatePaymentConfigRequest contains fields to update payment configuration.
type UpdatePaymentConfigRequest struct {
Enabled *bool `json:"enabled"`
MinAmount *float64 `json:"min_amount"`
MaxAmount *float64 `json:"max_amount"`
DailyLimit *float64 `json:"daily_limit"`
OrderTimeoutMin *int `json:"order_timeout_minutes"`
MaxPendingOrders *int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled *bool `json:"balance_disabled"`
LoadBalanceStrategy *string `json:"load_balance_strategy"`
ProductNamePrefix *string `json:"product_name_prefix"`
ProductNameSuffix *string `json:"product_name_suffix"`
HelpImageURL *string `json:"help_image_url"`
HelpText *string `json:"help_text"`
// Cancel rate limit settings
CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
CancelRateLimitMax *int `json:"cancel_rate_limit_max"`
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
}
// MethodLimits holds per-payment-type limits.
type MethodLimits struct {
PaymentType string `json:"payment_type"`
FeeRate float64 `json:"fee_rate"`
DailyLimit float64 `json:"daily_limit"`
SingleMin float64 `json:"single_min"`
SingleMax float64 `json:"single_max"`
}
// MethodLimitsResponse is the full response for the user-facing /limits API.
// It includes per-method limits and the global widest range (union of all methods).
type MethodLimitsResponse struct {
Methods map[string]MethodLimits `json:"methods"`
GlobalMin float64 `json:"global_min"` // 0 = no minimum
GlobalMax float64 `json:"global_max"` // 0 = no maximum
}
type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
}
type CreatePlanRequest struct {
GroupID int64 `json:"group_id"`
Name string `json:"name"`
Description string `json:"description"`
Price float64 `json:"price"`
OriginalPrice *float64 `json:"original_price"`
ValidityDays int `json:"validity_days"`
ValidityUnit string `json:"validity_unit"`
Features string `json:"features"`
ProductName string `json:"product_name"`
ForSale bool `json:"for_sale"`
SortOrder int `json:"sort_order"`
}
type UpdatePlanRequest struct {
GroupID *int64 `json:"group_id"`
Name *string `json:"name"`
Description *string `json:"description"`
Price *float64 `json:"price"`
OriginalPrice *float64 `json:"original_price"`
ValidityDays *int `json:"validity_days"`
ValidityUnit *string `json:"validity_unit"`
Features *string `json:"features"`
ProductName *string `json:"product_name"`
ForSale *bool `json:"for_sale"`
SortOrder *int `json:"sort_order"`
}
// PaymentConfigService manages payment configuration and CRUD for
// provider instances, channels, and subscription plans.
type PaymentConfigService struct {
entClient *dbent.Client
settingRepo SettingRepository
encryptionKey []byte
}
// NewPaymentConfigService creates a new PaymentConfigService.
func NewPaymentConfigService(entClient *dbent.Client, settingRepo SettingRepository, encryptionKey []byte) *PaymentConfigService {
return &PaymentConfigService{entClient: entClient, settingRepo: settingRepo, encryptionKey: encryptionKey}
}
// IsPaymentEnabled returns whether the payment system is enabled.
func (s *PaymentConfigService) IsPaymentEnabled(ctx context.Context) bool {
val, err := s.settingRepo.GetValue(ctx, SettingPaymentEnabled)
if err != nil {
return false
}
return val == "true"
}
// GetPaymentConfig returns the full payment configuration.
func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentConfig, error) {
keys := []string{
SettingPaymentEnabled, SettingMinRechargeAmount, SettingMaxRechargeAmount,
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy,
SettingProductNamePrefix, SettingProductNameSuffix,
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
}
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
cfg := &PaymentConfig{
Enabled: vals[SettingPaymentEnabled] == "true",
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
ProductNamePrefix: vals[SettingProductNamePrefix],
ProductNameSuffix: vals[SettingProductNameSuffix],
HelpImageURL: vals[SettingHelpImageURL],
HelpText: vals[SettingHelpText],
CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
CancelRateLimitUnit: vals[SettingCancelWindowUnit],
CancelRateLimitMode: vals[SettingCancelWindowMode],
}
if cfg.LoadBalanceStrategy == "" {
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
}
}
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe),
).Limit(1).All(ctx)
if err != nil || len(instances) == 0 {
return ""
}
cfg, err := s.decryptConfig(instances[0].Config)
if err != nil || cfg == nil {
return ""
}
return cfg[payment.ConfigKeyPublishableKey]
}
// UpdatePaymentConfig updates the payment configuration settings.
// NOTE: This function exceeds 30 lines because each field requires an independent
// nil-check before serialisation — this is inherent to patch-style update patterns
// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
m := map[string]string{
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
SettingHelpImageURL: derefStr(req.HelpImageURL),
SettingHelpText: derefStr(req.HelpText),
SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
} else {
m[SettingEnabledPaymentTypes] = ""
}
return s.settingRepo.SetMultiple(ctx, m)
}
func formatBoolOrEmpty(v *bool) string {
if v == nil {
return ""
}
return strconv.FormatBool(*v)
}
func formatPositiveFloat(v *float64) string {
if v == nil || *v <= 0 {
return "" // empty → parsePaymentConfig uses default
}
return strconv.FormatFloat(*v, 'f', 2, 64)
}
func formatPositiveInt(v *int) string {
if v == nil || *v <= 0 {
return ""
}
return strconv.Itoa(*v)
}
func derefStr(v *string) string {
if v == nil {
return ""
}
return *v
}
func splitTypes(s string) []string {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
func joinTypes(types []string) string {
return strings.Join(types, ",")
}
func pcParseFloat(s string, defaultVal float64) float64 {
if s == "" {
return defaultVal
}
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return defaultVal
}
return v
}
func pcParseInt(s string, defaultVal int) int {
if s == "" {
return defaultVal
}
v, err := strconv.Atoi(s)
if err != nil {
return defaultVal
}
return v
}
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestPcParseFloat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultVal float64
expected float64
}{
{"empty string returns default", "", 1.0, 1.0},
{"valid float", "3.14", 0, 3.14},
{"valid integer as float", "42", 0, 42.0},
{"invalid string returns default", "notanumber", 9.99, 9.99},
{"zero value", "0", 5.0, 0},
{"negative value", "-10.5", 0, -10.5},
{"very large value", "99999999.99", 0, 99999999.99},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := pcParseFloat(tt.input, tt.defaultVal)
if got != tt.expected {
t.Fatalf("pcParseFloat(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
}
})
}
}
func TestPcParseInt(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultVal int
expected int
}{
{"empty string returns default", "", 30, 30},
{"valid int", "10", 0, 10},
{"invalid string returns default", "abc", 5, 5},
{"float string returns default", "3.14", 0, 0},
{"zero value", "0", 99, 0},
{"negative value", "-1", 0, -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := pcParseInt(tt.input, tt.defaultVal)
if got != tt.expected {
t.Fatalf("pcParseInt(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
}
})
}
}
func TestParsePaymentConfig(t *testing.T) {
t.Parallel()
svc := &PaymentConfigService{}
t.Run("empty vals uses defaults", func(t *testing.T) {
t.Parallel()
cfg := svc.parsePaymentConfig(map[string]string{})
if cfg.Enabled {
t.Fatal("expected Enabled=false by default")
}
if cfg.MinAmount != 1 {
t.Fatalf("expected MinAmount=1, got %v", cfg.MinAmount)
}
if cfg.MaxAmount != 0 {
t.Fatalf("expected MaxAmount=0 (no limit), got %v", cfg.MaxAmount)
}
if cfg.OrderTimeoutMin != 30 {
t.Fatalf("expected OrderTimeoutMin=30, got %v", cfg.OrderTimeoutMin)
}
if cfg.MaxPendingOrders != 3 {
t.Fatalf("expected MaxPendingOrders=3, got %v", cfg.MaxPendingOrders)
}
if cfg.LoadBalanceStrategy != payment.DefaultLoadBalanceStrategy {
t.Fatalf("expected LoadBalanceStrategy=%s, got %q", payment.DefaultLoadBalanceStrategy, cfg.LoadBalanceStrategy)
}
if len(cfg.EnabledTypes) != 0 {
t.Fatalf("expected empty EnabledTypes, got %v", cfg.EnabledTypes)
}
})
t.Run("all values populated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingPaymentEnabled: "true",
SettingMinRechargeAmount: "5.00",
SettingMaxRechargeAmount: "1000.00",
SettingDailyRechargeLimit: "5000.00",
SettingOrderTimeoutMinutes: "15",
SettingMaxPendingOrders: "5",
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
SettingBalancePayDisabled: "true",
SettingLoadBalanceStrategy: "least_amount",
SettingProductNamePrefix: "PRE",
SettingProductNameSuffix: "SUF",
}
cfg := svc.parsePaymentConfig(vals)
if !cfg.Enabled {
t.Fatal("expected Enabled=true")
}
if cfg.MinAmount != 5 {
t.Fatalf("MinAmount = %v, want 5", cfg.MinAmount)
}
if cfg.MaxAmount != 1000 {
t.Fatalf("MaxAmount = %v, want 1000", cfg.MaxAmount)
}
if cfg.DailyLimit != 5000 {
t.Fatalf("DailyLimit = %v, want 5000", cfg.DailyLimit)
}
if cfg.OrderTimeoutMin != 15 {
t.Fatalf("OrderTimeoutMin = %v, want 15", cfg.OrderTimeoutMin)
}
if cfg.MaxPendingOrders != 5 {
t.Fatalf("MaxPendingOrders = %v, want 5", cfg.MaxPendingOrders)
}
if len(cfg.EnabledTypes) != 3 {
t.Fatalf("EnabledTypes len = %d, want 3", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" || cfg.EnabledTypes[2] != "stripe" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay stripe]", cfg.EnabledTypes)
}
if !cfg.BalanceDisabled {
t.Fatal("expected BalanceDisabled=true")
}
if cfg.LoadBalanceStrategy != "least_amount" {
t.Fatalf("LoadBalanceStrategy = %q, want %q", cfg.LoadBalanceStrategy, "least_amount")
}
if cfg.ProductNamePrefix != "PRE" {
t.Fatalf("ProductNamePrefix = %q, want %q", cfg.ProductNamePrefix, "PRE")
}
if cfg.ProductNameSuffix != "SUF" {
t.Fatalf("ProductNameSuffix = %q, want %q", cfg.ProductNameSuffix, "SUF")
}
})
t.Run("enabled types with spaces are trimmed", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: " alipay , wxpay ",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 0 {
t.Fatalf("expected empty EnabledTypes for empty string, got %v", cfg.EnabledTypes)
}
})
}
func TestGetBasePaymentType(t *testing.T) {
t.Parallel()
tests := []struct {
input string
expected string
}{
{payment.TypeEasyPay, payment.TypeEasyPay},
{payment.TypeStripe, payment.TypeStripe},
{payment.TypeCard, payment.TypeStripe},
{payment.TypeLink, payment.TypeStripe},
{payment.TypeAlipay, payment.TypeAlipay},
{payment.TypeAlipayDirect, payment.TypeAlipay},
{payment.TypeWxpay, payment.TypeWxpay},
{payment.TypeWxpayDirect, payment.TypeWxpay},
{"unknown", "unknown"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
got := payment.GetBasePaymentType(tt.input)
if got != tt.expected {
t.Fatalf("GetBasePaymentType(%q) = %q, want %q", tt.input, got, tt.expected)
}
})
}
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Payment Notification & Fulfillment ---
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
if n.Status != payment.NotificationStatusSuccess {
return nil
}
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
return nil
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
grace := now.Add(-paymentGraceMinutes * time.Minute)
c, err := s.entClient.PaymentOrder.Update().Where(
paymentorder.IDEQ(o.ID),
paymentorder.Or(
paymentorder.StatusEQ(OrderStatusPending),
paymentorder.StatusEQ(OrderStatusCancelled),
paymentorder.And(
paymentorder.StatusEQ(OrderStatusExpired),
paymentorder.UpdatedAtGTE(grace),
),
),
).SetStatus(OrderStatusPaid).SetPayAmount(paid).SetPaymentTradeNo(tradeNo).SetPaidAt(now).ClearFailedAt().ClearFailedReason().Save(ctx)
if err != nil {
return fmt.Errorf("update to PAID: %w", err)
}
if c == 0 {
return s.alreadyProcessed(ctx, o)
}
if previousStatus == OrderStatusCancelled || previousStatus == OrderStatusExpired {
slog.Info("order recovered from webhook payment success",
"orderID", o.ID,
"previousStatus", previousStatus,
"tradeNo", tradeNo,
"provider", pk,
)
s.writeAuditLog(ctx, o.ID, "ORDER_RECOVERED", pk, map[string]any{
"previous_status": previousStatus,
"tradeNo": tradeNo,
"paidAmount": paid,
"reason": "webhook payment success received after order " + previousStatus,
})
}
s.writeAuditLog(ctx, o.ID, "ORDER_PAID", pk, map[string]any{"tradeNo": tradeNo, "paidAmount": paid})
return s.executeFulfillment(ctx, o.ID)
}
func (s *PaymentService) alreadyProcessed(ctx context.Context, o *dbent.PaymentOrder) error {
cur, err := s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil
}
switch cur.Status {
case OrderStatusCompleted, OrderStatusRefunded:
return nil
case OrderStatusFailed:
return s.executeFulfillment(ctx, o.ID)
case OrderStatusPaid, OrderStatusRecharging:
return fmt.Errorf("order %d is being processed", o.ID)
case OrderStatusExpired:
slog.Warn("webhook payment success for expired order beyond grace period",
"orderID", o.ID,
"status", cur.Status,
"updatedAt", cur.UpdatedAt,
)
s.writeAuditLog(ctx, o.ID, "PAYMENT_AFTER_EXPIRY", "system", map[string]any{
"status": cur.Status,
"updatedAt": cur.UpdatedAt,
"reason": "payment arrived after expiry grace period",
})
return nil
default:
return nil
}
}
func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return fmt.Errorf("get order: %w", err)
}
if o.OrderType == payment.OrderTypeSubscription {
return s.ExecuteSubscriptionFulfillment(ctx, oid)
}
return s.ExecuteBalanceFulfillment(ctx, oid)
}
func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusCompleted {
return nil
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
}
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
if err != nil {
return fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil
}
if err := s.doBalance(ctx, o); err != nil {
s.markFailed(ctx, oid, err)
return err
}
return nil
}
// redeemAction represents the idempotency decision for balance fulfillment.
type redeemAction int
const (
// redeemActionCreate: code does not exist — create it, then redeem.
redeemActionCreate redeemAction = iota
// redeemActionRedeem: code exists but is unused — skip creation, redeem only.
redeemActionRedeem
// redeemActionSkipCompleted: code exists and is already used — skip to mark completed.
redeemActionSkipCompleted
)
// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup.
// existing is the result of GetByCode; lookupErr is the error from that call.
func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction {
if existing == nil || lookupErr != nil {
return redeemActionCreate
}
if existing.IsUsed() {
return redeemActionSkipCompleted
}
return redeemActionRedeem
}
func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error {
// Idempotency: check if redeem code already exists (from a previous partial run)
existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode)
action := resolveRedeemAction(existing, lookupErr)
switch action {
case redeemActionSkipCompleted:
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused}
if err := s.redeemService.CreateCode(ctx, rc); err != nil {
return fmt.Errorf("create redeem code: %w", err)
}
case redeemActionRedeem:
// Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrder, auditAction string) error {
now := time.Now()
_, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx)
if err != nil {
return fmt.Errorf("mark completed: %w", err)
}
s.writeAuditLog(ctx, o.ID, auditAction, "system", map[string]any{"rechargeCode": o.RechargeCode, "amount": o.Amount})
return nil
}
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusCompleted {
return nil
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
}
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
}
if o.SubscriptionGroupID == nil || o.SubscriptionDays == nil {
return infraerrors.BadRequest("INVALID_STATUS", "missing subscription info")
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
if err != nil {
return fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil
}
if err := s.doSub(ctx, o); err != nil {
s.markFailed(ctx, oid, err)
return err
}
return nil
}
func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error {
gid := *o.SubscriptionGroupID
days := *o.SubscriptionDays
g, err := s.groupRepo.GetByID(ctx, gid)
if err != nil || g.Status != payment.EntityStatusActive {
return fmt.Errorf("group %d no longer exists or inactive", gid)
}
// Idempotency: check audit log to see if subscription was already assigned.
// Prevents double-extension on retry after markCompleted fails.
if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") {
slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid)
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
}
orderNote := fmt.Sprintf("payment order %d", o.ID)
_, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote})
if err != nil {
return fmt.Errorf("assign subscription: %w", err)
}
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
}
func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool {
oid := strconv.FormatInt(orderID, 10)
c, _ := s.entClient.PaymentAuditLog.Query().
Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)).
Limit(1).Count(ctx)
return c > 0
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
// Only mark FAILED if still in RECHARGING state — prevents overwriting
// a COMPLETED order when markCompleted failed but fulfillment succeeded.
c, e := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)).
SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
if e != nil {
slog.Error("mark FAILED", "orderID", oid, "error", e)
}
if c > 0 {
s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
}
}
func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.PaidAt == nil {
return infraerrors.BadRequest("INVALID_STATUS", "order is not paid")
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot retry")
}
if o.Status == OrderStatusRecharging {
return infraerrors.Conflict("CONFLICT", "order is being processed")
}
if o.Status == OrderStatusCompleted {
return infraerrors.BadRequest("INVALID_STATUS", "order already completed")
}
if o.Status != OrderStatusFailed && o.Status != OrderStatusPaid {
return infraerrors.BadRequest("INVALID_STATUS", "only paid and failed orders can retry")
}
_, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusFailed, OrderStatusPaid)).SetStatus(OrderStatusPaid).ClearFailedAt().ClearFailedReason().Save(ctx)
if err != nil {
return fmt.Errorf("reset for retry: %w", err)
}
s.writeAuditLog(ctx, oid, "RECHARGE_RETRY", "admin", map[string]any{"detail": "admin manual retry"})
return s.executeFulfillment(ctx, oid)
}
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_CodeNotFound(t *testing.T) {
t.Parallel()
action := resolveRedeemAction(nil, nil)
assert.Equal(t, redeemActionCreate, action, "nil code with nil error should create")
}
func TestResolveRedeemAction_LookupError(t *testing.T) {
t.Parallel()
action := resolveRedeemAction(nil, errors.New("db connection lost"))
assert.Equal(t, redeemActionCreate, action, "lookup error should fall back to create")
}
func TestResolveRedeemAction_LookupErrorWithNonNilCode(t *testing.T) {
t.Parallel()
// Edge case: both code and error are non-nil (shouldn't happen in practice,
// but the function should still treat error as authoritative)
code := &RedeemCode{Status: StatusUnused}
action := resolveRedeemAction(code, errors.New("partial error"))
assert.Equal(t, redeemActionCreate, action, "non-nil error should always result in create regardless of code")
}
func TestResolveRedeemAction_CodeExistsAndUsed(t *testing.T) {
t.Parallel()
code := &RedeemCode{
Code: "test-code-123",
Status: StatusUsed,
Type: RedeemTypeBalance,
Value: 10.0,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionSkipCompleted, action, "used code should skip to completed")
}
func TestResolveRedeemAction_CodeExistsAndUnused(t *testing.T) {
t.Parallel()
code := &RedeemCode{
Code: "test-code-456",
Status: StatusUnused,
Type: RedeemTypeBalance,
Value: 25.0,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionRedeem, action, "unused code should skip creation and proceed to redeem")
}
func TestResolveRedeemAction_CodeExistsWithExpiredStatus(t *testing.T) {
t.Parallel()
// A code with a non-standard status (neither "unused" nor "used")
// should NOT be treated as used, so it falls through to redeemActionRedeem.
code := &RedeemCode{
Code: "expired-code",
Status: StatusExpired,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionRedeem, action, "expired-status code is not IsUsed(), should redeem")
}
// ---------------------------------------------------------------------------
// Table-driven comprehensive test
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_Table(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code *RedeemCode
err error
expected redeemAction
}{
{
name: "nil code, nil error — first run",
code: nil,
err: nil,
expected: redeemActionCreate,
},
{
name: "nil code, lookup error — treat as not found",
code: nil,
err: ErrRedeemCodeNotFound,
expected: redeemActionCreate,
},
{
name: "nil code, generic DB error — treat as not found",
code: nil,
err: errors.New("connection refused"),
expected: redeemActionCreate,
},
{
name: "code exists, used — previous run completed redeem",
code: &RedeemCode{Status: StatusUsed},
err: nil,
expected: redeemActionSkipCompleted,
},
{
name: "code exists, unused — previous run created code but crashed before redeem",
code: &RedeemCode{Status: StatusUnused},
err: nil,
expected: redeemActionRedeem,
},
{
name: "code exists but error also set — error takes precedence",
code: &RedeemCode{Status: StatusUsed},
err: errors.New("unexpected"),
expected: redeemActionCreate,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := resolveRedeemAction(tt.code, tt.err)
assert.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// redeemAction enum value sanity
// ---------------------------------------------------------------------------
func TestRedeemAction_DistinctValues(t *testing.T) {
t.Parallel()
// Ensure the three actions have distinct values (iota correctness)
assert.NotEqual(t, redeemActionCreate, redeemActionRedeem)
assert.NotEqual(t, redeemActionCreate, redeemActionSkipCompleted)
assert.NotEqual(t, redeemActionRedeem, redeemActionSkipCompleted)
}
// ---------------------------------------------------------------------------
// RedeemCode.IsUsed / CanUse interaction with resolveRedeemAction
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
t.Parallel()
usedCode := &RedeemCode{Status: StatusUsed}
unusedCode := &RedeemCode{Status: StatusUnused}
// Verify our decision function is consistent with the domain model methods
assert.True(t, usedCode.IsUsed())
assert.False(t, usedCode.CanUse())
assert.Equal(t, redeemActionSkipCompleted, resolveRedeemAction(usedCode, nil))
assert.False(t, unusedCode.IsUsed())
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Order Creation ---
func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest) (*CreateOrderResponse, error) {
if req.OrderType == "" {
req.OrderType = payment.OrderTypeBalance
}
cfg, err := s.configService.GetPaymentConfig(ctx)
if err != nil {
return nil, fmt.Errorf("get payment config: %w", err)
}
if !cfg.Enabled {
return nil, infraerrors.Forbidden("PAYMENT_DISABLED", "payment system is disabled")
}
plan, err := s.validateOrderInput(ctx, req, cfg)
if err != nil {
return nil, err
}
if err := s.checkCancelRateLimit(ctx, req.UserID, cfg); err != nil {
return nil, err
}
user, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user.Status != payment.EntityStatusActive {
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
}
amount := req.Amount
if plan != nil {
amount = plan.Price
}
feeRate := s.getFeeRate(req.PaymentType)
payAmountStr := payment.CalculatePayAmount(amount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, amount, feeRate, payAmount)
if err != nil {
return nil, err
}
resp, err := s.invokeProvider(ctx, order, req, cfg, payAmountStr, payAmount, plan)
if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed).
Save(ctx)
return nil, err
}
return resp, nil
}
func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig) (*dbent.SubscriptionPlan, error) {
if req.OrderType == payment.OrderTypeBalance && cfg.BalanceDisabled {
return nil, infraerrors.Forbidden("BALANCE_PAYMENT_DISABLED", "balance recharge has been disabled")
}
if req.OrderType == payment.OrderTypeSubscription {
return s.validateSubOrder(ctx, req)
}
if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number")
}
if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range").
WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)})
}
return nil, nil
}
func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRequest) (*dbent.SubscriptionPlan, error) {
if req.PlanID == 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "subscription order requires a plan")
}
plan, err := s.configService.GetPlan(ctx, req.PlanID)
if err != nil || !plan.ForSale {
return nil, infraerrors.NotFound("PLAN_NOT_AVAILABLE", "plan not found or not for sale")
}
group, err := s.groupRepo.GetByID(ctx, plan.GroupID)
if err != nil || group.Status != payment.EntityStatusActive {
return nil, infraerrors.NotFound("GROUP_NOT_FOUND", "subscription group is no longer available")
}
if !group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("GROUP_TYPE_MISMATCH", "group is not a subscription type")
}
return plan, nil
}
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, amount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := s.checkPendingLimit(ctx, tx, req.UserID, cfg.MaxPendingOrders); err != nil {
return nil, err
}
if err := s.checkDailyLimit(ctx, tx, req.UserID, amount, cfg.DailyLimit); err != nil {
return nil, err
}
tm := cfg.OrderTimeoutMin
if tm <= 0 {
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetNillableUserNotes(psNilIfEmpty(user.Notes)).
SetAmount(amount).
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
SetStatus(OrderStatusPending).
SetExpiresAt(exp).
SetClientIP(req.ClientIP).
SetSrcHost(req.SrcHost)
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
order, err := b.Save(ctx)
if err != nil {
return nil, fmt.Errorf("create order: %w", err)
}
code := fmt.Sprintf("PAY-%d-%d", order.ID, time.Now().UnixNano()%100000)
order, err = tx.PaymentOrder.UpdateOneID(order.ID).SetRechargeCode(code).Save(ctx)
if err != nil {
return nil, fmt.Errorf("set recharge code: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit order transaction: %w", err)
}
return order, nil
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
}
c, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusEQ(OrderStatusPending)).Count(ctx)
if err != nil {
return fmt.Errorf("count pending orders: %w", err)
}
if c >= max {
return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
}
return nil
}
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
return nil
}
windowStart := cancelRateLimitWindowStart(cfg)
operator := fmt.Sprintf("user:%d", userID)
count, err := s.entClient.PaymentAuditLog.Query().
Where(
paymentauditlog.ActionEQ("ORDER_CANCELLED"),
paymentauditlog.OperatorEQ(operator),
paymentauditlog.CreatedAtGTE(windowStart),
).Count(ctx)
if err != nil {
slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
return nil // fail open
}
if count >= cfg.CancelRateLimitMax {
return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
WithMetadata(map[string]string{
"max": strconv.Itoa(cfg.CancelRateLimitMax),
"window": strconv.Itoa(cfg.CancelRateLimitWindow),
"unit": cfg.CancelRateLimitUnit,
})
}
return nil
}
func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
now := time.Now()
w := cfg.CancelRateLimitWindow
if w <= 0 {
w = 1
}
unit := cfg.CancelRateLimitUnit
if unit == "" {
unit = "day"
}
if cfg.CancelRateLimitMode == "fixed" {
switch unit {
case "minute":
t := now.Truncate(time.Minute)
return t.Add(-time.Duration(w-1) * time.Minute)
case "day":
y, m, d := now.Date()
t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
return t.AddDate(0, 0, -(w - 1))
default: // hour
t := now.Truncate(time.Hour)
return t.Add(-time.Duration(w-1) * time.Hour)
}
}
// rolling window
switch unit {
case "minute":
return now.Add(-time.Duration(w) * time.Minute)
case "day":
return now.AddDate(0, 0, -w)
default: // hour
return now.Add(-time.Duration(w) * time.Hour)
}
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
}
ts := psStartOfDayUTC(time.Now())
orders, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusIn(OrderStatusPaid, OrderStatusRecharging, OrderStatusCompleted), paymentorder.PaidAtGTE(ts)).All(ctx)
if err != nil {
return fmt.Errorf("query daily usage: %w", err)
}
var used float64
for _, o := range orders {
used += o.Amount
}
if used+amount > limit {
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
}
return nil
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
s.EnsureProviders(ctx)
providerKey := s.registry.GetProviderKey(req.PaymentType)
if providerKey == "" {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, fmt.Errorf("select provider instance: %w", err)
}
if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
}
prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
}
subject := s.buildPaymentSubject(plan, payAmountStr, cfg)
outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
s.writeAuditLog(ctx, order.ID, "ORDER_CREATED", fmt.Sprintf("user:%d", req.UserID), map[string]any{"amount": req.Amount, "paymentType": req.PaymentType, "orderType": req.OrderType})
return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, payAmountStr string, cfg *PaymentConfig) string {
if plan != nil {
if plan.ProductName != "" {
return plan.ProductName
}
return "Sub2API Subscription " + plan.Name
}
pf := strings.TrimSpace(cfg.ProductNamePrefix)
sf := strings.TrimSpace(cfg.ProductNameSuffix)
if pf != "" || sf != "" {
return strings.TrimSpace(pf + " " + payAmountStr + " " + sf)
}
return "Sub2API " + payAmountStr + " CNY"
}
// --- Order Queries ---
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
return o, nil
}
func (s *PaymentService) GetOrderByID(ctx context.Context, orderID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return o, nil
}
func (s *PaymentService) GetUserOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
q := s.entClient.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID))
if p.Status != "" {
q = q.Where(paymentorder.StatusEQ(p.Status))
}
if p.OrderType != "" {
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
}
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count user orders: %w", err)
}
ps, pg := applyPagination(p.PageSize, p.Page)
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
if err != nil {
return nil, 0, fmt.Errorf("query user orders: %w", err)
}
return orders, total, nil
}
// AdminListOrders returns a paginated list of orders. If userID > 0, filters by user.
func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
q := s.entClient.PaymentOrder.Query()
if userID > 0 {
q = q.Where(paymentorder.UserIDEQ(userID))
}
if p.Status != "" {
q = q.Where(paymentorder.StatusEQ(p.Status))
}
if p.OrderType != "" {
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
}
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
if p.Keyword != "" {
q = q.Where(paymentorder.Or(
paymentorder.OutTradeNoContainsFold(p.Keyword),
paymentorder.UserEmailContainsFold(p.Keyword),
paymentorder.UserNameContainsFold(p.Keyword),
))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count admin orders: %w", err)
}
ps, pg := applyPagination(p.PageSize, p.Page)
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
if err != nil {
return nil, 0, fmt.Errorf("query admin orders: %w", err)
}
return orders, total, nil
}
// --- Cancel & Expire ---
func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
}
func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
}
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
if o.PaymentTradeNo != "" || o.PaymentType != "" {
if s.checkPaid(ctx, o) == "already_paid" {
return "already_paid", nil
}
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
if err != nil {
return "", fmt.Errorf("update order status: %w", err)
}
if c > 0 {
auditAction := "ORDER_CANCELLED"
if fs == OrderStatusExpired {
auditAction = "ORDER_EXPIRED"
}
s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
}
return "cancelled", nil
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
}
resp, err := prov.QueryOrder(ctx, tradeNo)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return "already_paid"
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
}
return ""
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
// Only verify orders that are still pending or recently expired
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
// Reload order to get updated status
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
if err != nil {
return 0, fmt.Errorf("query expired: %w", err)
}
n := 0
for _, o := range orders {
// Check upstream payment status before expiring — the user may have
// paid just before timeout and the webhook hasn't arrived yet.
outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
if outcome == "already_paid" {
slog.Info("order was paid during expiry", "orderID", o.ID)
continue
}
if outcome != "" {
n++
}
}
return n, nil
}
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
}
}
}
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
package service
import (
"context"
"log/slog"
"sync"
"time"
)
const expiryCheckTimeout = 30 * time.Second
// PaymentOrderExpiryService periodically expires timed-out payment orders.
type PaymentOrderExpiryService struct {
paymentSvc *PaymentService
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewPaymentOrderExpiryService(paymentSvc *PaymentService, interval time.Duration) *PaymentOrderExpiryService {
return &PaymentOrderExpiryService{
paymentSvc: paymentSvc,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *PaymentOrderExpiryService) Start() {
if s == nil || s.paymentSvc == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *PaymentOrderExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *PaymentOrderExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
defer cancel()
expired, err := s.paymentSvc.ExpireTimedOutOrders(ctx)
if err != nil {
slog.Error("[PaymentOrderExpiry] failed to expire orders", "error", err)
return
}
if expired > 0 {
slog.Info("[PaymentOrderExpiry] expired timed-out orders", "count", expired)
}
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Refund Flow ---
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
return err
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if u.Balance < o.Amount {
return infraerrors.BadRequest("BALANCE_NOT_ENOUGH", "refund amount exceeds balance")
}
nr := strings.TrimSpace(reason)
now := time.Now()
by := fmt.Sprintf("%d", uid)
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.UserIDEQ(uid), paymentorder.StatusEQ(OrderStatusCompleted), paymentorder.OrderTypeEQ(payment.OrderTypeBalance)).SetStatus(OrderStatusRefundRequested).SetRefundRequestedAt(now).SetRefundRequestReason(nr).SetRefundRequestedBy(by).SetRefundAmount(o.Amount).Save(ctx)
if err != nil {
return fmt.Errorf("update: %w", err)
}
if c == 0 {
return infraerrors.Conflict("CONFLICT", "order status changed")
}
s.writeAuditLog(ctx, oid, "REFUND_REQUESTED", fmt.Sprintf("user:%d", uid), map[string]any{"amount": o.Amount, "reason": nr})
return nil
}
func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != uid {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission")
}
if o.OrderType != payment.OrderTypeBalance {
return nil, infraerrors.BadRequest("INVALID_ORDER_TYPE", "only balance orders can request refund")
}
if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
return o, nil
}
func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float64, reason string, force, deduct bool) (*RefundPlan, *RefundResult, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return nil, nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
ok := []string{OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed}
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
}
if amt <= 0 {
amt = o.Amount
}
if amt-o.Amount > amountToleranceCNY {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
}
// Full refund: use actual pay_amount for gateway (includes fees)
ga := amt
if math.Abs(amt-o.Amount) <= amountToleranceCNY {
ga = o.PayAmount
}
rr := strings.TrimSpace(reason)
if rr == "" && o.RefundRequestReason != nil {
rr = *o.RefundRequestReason
}
if rr == "" {
rr = fmt.Sprintf("refund order:%d", o.ID)
}
p := &RefundPlan{OrderID: oid, Order: o, RefundAmount: amt, GatewayAmount: ga, Reason: rr, Force: force, DeductBalance: deduct, DeductionType: payment.DeductionTypeNone}
if deduct {
if er := s.prepDeduct(ctx, o, p, force); er != nil {
return nil, er, nil
}
}
return p, nil, nil
}
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
if o.OrderType == payment.OrderTypeSubscription {
p.DeductionType = payment.DeductionTypeSubscription
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
p.SubDaysToDeduct = *o.SubscriptionDays
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
if err == nil && sub != nil {
p.SubscriptionID = sub.ID
} else if !force {
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
}
}
return nil
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
if err != nil {
if !force {
return &RefundResult{Success: false, Warning: "cannot fetch user balance, use force", RequireForce: true}
}
return nil
}
p.DeductionType = payment.DeductionTypeBalance
p.BalanceToDeduct = math.Min(p.RefundAmount, u.Balance)
return nil
}
func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(p.OrderID), paymentorder.StatusIn(OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed)).SetStatus(OrderStatusRefunding).Save(ctx)
if err != nil {
return nil, fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil, infraerrors.Conflict("CONFLICT", "order status changed")
}
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
// Skip balance deduction on retry if previous attempt already deducted
// but failed to roll back (REFUND_ROLLBACK_FAILED in audit log).
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("deduction: %w", err)
}
} else {
slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.BalanceToDeduct = 0
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil {
// If deducting would expire the subscription, revoke it entirely
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
}
} else {
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.SubDaysToDeduct = 0
}
}
if err := s.gwRefund(ctx, p); err != nil {
return s.handleGwFail(ctx, p, err)
}
return s.markRefundOk(ctx, p)
}
func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if p.Order.PaymentTradeNo == "" {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
return nil
}
// Use the exact provider instance that created this order, not a random one
// from the registry. Each instance has its own merchant credentials.
prov, err := s.getRefundProvider(ctx, p.Order)
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
Reason: p.Reason,
})
return err
}
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
return s.getOrderProvider(ctx, o)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
if s.RollbackRefund(ctx, p, gErr) {
s.restoreStatus(ctx, p)
s.writeAuditLog(ctx, p.OrderID, "REFUND_GATEWAY_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
return &RefundResult{Success: false, Warning: "gateway failed: " + psErrMsg(gErr) + ", rolled back"}, nil
}
now := time.Now()
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(OrderStatusRefundFailed).SetFailedAt(now).SetFailedReason(psErrMsg(gErr)).Save(ctx)
s.writeAuditLog(ctx, p.OrderID, "REFUND_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
return nil, infraerrors.InternalServer("REFUND_FAILED", psErrMsg(gErr))
}
func (s *PaymentService) markRefundOk(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
fs := OrderStatusRefunded
if p.RefundAmount < p.Order.Amount {
fs = OrderStatusPartiallyRefunded
}
now := time.Now()
_, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(fs).SetRefundAmount(p.RefundAmount).SetRefundReason(p.Reason).SetRefundAt(now).SetForceRefund(p.Force).Save(ctx)
if err != nil {
return nil, fmt.Errorf("mark refund: %w", err)
}
s.writeAuditLog(ctx, p.OrderID, "REFUND_SUCCESS", "admin", map[string]any{"refundAmount": p.RefundAmount, "reason": p.Reason, "balanceDeducted": p.BalanceToDeduct, "force": p.Force})
return &RefundResult{Success: true, BalanceDeducted: p.BalanceToDeduct, SubDaysDeducted: p.SubDaysToDeduct}, nil
}
func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr error) bool {
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
if err := s.userRepo.UpdateBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
slog.Error("[CRITICAL] rollback failed", "orderID", p.OrderID, "amount", p.BalanceToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "balanceDeducted": p.BalanceToDeduct})
return false
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
return false
}
}
return true
}
func (s *PaymentService) restoreStatus(ctx context.Context, p *RefundPlan) {
rs := OrderStatusCompleted
if p.Order.Status == OrderStatusRefundRequested {
rs = OrderStatusRefundRequested
}
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(rs).Save(ctx)
}
package service
import (
"context"
"fmt"
"log/slog"
"math/rand/v2"
"sync"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
)
// --- Order Status Constants ---
const (
OrderStatusPending = payment.OrderStatusPending
OrderStatusPaid = payment.OrderStatusPaid
OrderStatusRecharging = payment.OrderStatusRecharging
OrderStatusCompleted = payment.OrderStatusCompleted
OrderStatusExpired = payment.OrderStatusExpired
OrderStatusCancelled = payment.OrderStatusCancelled
OrderStatusFailed = payment.OrderStatusFailed
OrderStatusRefundRequested = payment.OrderStatusRefundRequested
OrderStatusRefunding = payment.OrderStatusRefunding
OrderStatusPartiallyRefunded = payment.OrderStatusPartiallyRefunded
OrderStatusRefunded = payment.OrderStatusRefunded
OrderStatusRefundFailed = payment.OrderStatusRefundFailed
)
const (
// defaultMaxPendingOrders and defaultOrderTimeoutMin are defined in
// payment_config_service.go alongside other payment configuration defaults.
paymentGraceMinutes = 5
defaultPageSize = 20
maxPageSize = 100
topUsersLimit = 10
amountToleranceCNY = 0.01
orderIDPrefix = "sub2_"
)
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
// Format: sub2_20250409aB3kX9mQ (prefix + date + 8-char random)
func generateOutTradeNo() string {
date := time.Now().Format("20060102")
rnd := generateRandomString(8)
return orderIDPrefix + date + rnd
}
func generateRandomString(n int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, n)
for i := range b {
b[i] = charset[rand.IntN(len(charset))]
}
return string(b)
}
type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
OrderType string
PlanID int64
}
type CreateOrderResponse struct {
OrderID int64 `json:"order_id"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Status string `json:"status"`
PaymentType string `json:"payment_type"`
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
}
type OrderListParams struct {
Page int
PageSize int
Status string
OrderType string
PaymentType string
Keyword string
}
type RefundPlan struct {
OrderID int64
Order *dbent.PaymentOrder
RefundAmount float64
GatewayAmount float64
Reason string
Force bool
DeductBalance bool
DeductionType string
BalanceToDeduct float64
SubDaysToDeduct int
SubscriptionID int64
}
type RefundResult struct {
Success bool `json:"success"`
Warning string `json:"warning,omitempty"`
RequireForce bool `json:"require_force,omitempty"`
BalanceDeducted float64 `json:"balance_deducted,omitempty"`
SubDaysDeducted int `json:"subscription_days_deducted,omitempty"`
}
type DashboardStats struct {
TodayAmount float64 `json:"today_amount"`
TotalAmount float64 `json:"total_amount"`
TodayCount int `json:"today_count"`
TotalCount int `json:"total_count"`
AvgAmount float64 `json:"avg_amount"`
PendingOrders int `json:"pending_orders"`
DailySeries []DailyStats `json:"daily_series"`
PaymentMethods []PaymentMethodStat `json:"payment_methods"`
TopUsers []TopUserStat `json:"top_users"`
}
type DailyStats struct {
Date string `json:"date"`
Amount float64 `json:"amount"`
Count int `json:"count"`
}
type PaymentMethodStat struct {
Type string `json:"type"`
Amount float64 `json:"amount"`
Count int `json:"count"`
}
type TopUserStat struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Amount float64 `json:"amount"`
}
// --- Service ---
type PaymentService struct {
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
}
// --- Provider Registry ---
// EnsureProviders lazily initializes the provider registry on first call.
func (s *PaymentService) EnsureProviders(ctx context.Context) {
s.providerMu.Lock()
defer s.providerMu.Unlock()
if !s.providersLoaded {
s.loadProviders(ctx)
s.providersLoaded = true
}
}
// RefreshProviders clears and re-registers all providers from the database.
func (s *PaymentService) RefreshProviders(ctx context.Context) {
s.providerMu.Lock()
defer s.providerMu.Unlock()
s.registry.Clear()
s.loadProviders(ctx)
s.providersLoaded = true
}
func (s *PaymentService) loadProviders(ctx context.Context) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
slog.Error("[PaymentService] failed to query provider instances", "error", err)
return
}
for _, inst := range instances {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
if err != nil {
slog.Warn("[PaymentService] failed to decrypt config for instance", "instanceID", inst.ID, "error", err)
continue
}
if inst.PaymentMode != "" {
cfg["paymentMode"] = inst.PaymentMode
}
instID := fmt.Sprintf("%d", inst.ID)
p, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
if err != nil {
slog.Warn("[PaymentService] failed to create provider for instance", "instanceID", inst.ID, "key", inst.ProviderKey, "error", err)
continue
}
s.registry.Register(p)
}
}
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It extracts out_trade_no from the raw body, looks up the order to find the
// original provider instance, and creates a provider with that instance's credentials.
// Falls back to the registry provider when the order cannot be found.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
p, pErr := s.getOrderProvider(ctx, order)
if pErr == nil {
return p, nil
}
slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
}
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
}
// --- Helpers ---
func psIsRefundStatus(s string) bool {
switch s {
case OrderStatusRefundRequested, OrderStatusRefunding, OrderStatusPartiallyRefunded, OrderStatusRefunded, OrderStatusRefundFailed:
return true
}
return false
}
func psErrMsg(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func psNilIfEmpty(s string) *string {
if s == "" {
return nil
}
return &s
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
return true
}
}
return false
}
func psComputeValidityDays(days int, unit string) int {
switch unit {
case "week":
return days * 7
case "month":
return days * 30
default:
return days
}
}
func (s *PaymentService) getFeeRate(_ string) float64 { return 0 }
func psStartOfDayUTC(t time.Time) time.Time {
y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func applyPagination(pageSize, page int) (size, pg int) {
size = pageSize
if size <= 0 {
size = defaultPageSize
}
if size > maxPageSize {
size = maxPageSize
}
pg = page
if pg < 1 {
pg = 1
}
return size, pg
}
package service
import (
"context"
"encoding/json"
"log/slog"
"math"
"sort"
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
)
// --- Dashboard & Analytics ---
func (s *PaymentService) GetDashboardStats(ctx context.Context, days int) (*DashboardStats, error) {
if days <= 0 {
days = 30
}
now := time.Now()
since := now.AddDate(0, 0, -days)
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
paidStatuses := []string{OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging}
orders, err := s.entClient.PaymentOrder.Query().
Where(
paymentorder.StatusIn(paidStatuses...),
paymentorder.PaidAtGTE(since),
).
All(ctx)
if err != nil {
return nil, err
}
st := &DashboardStats{}
computeBasicStats(st, orders, todayStart)
st.PendingOrders, err = s.entClient.PaymentOrder.Query().
Where(paymentorder.StatusEQ(OrderStatusPending)).
Count(ctx)
if err != nil {
return nil, err
}
st.DailySeries = buildDailySeries(orders, since, days)
st.PaymentMethods = buildMethodDistribution(orders)
st.TopUsers = buildTopUsers(orders)
return st, nil
}
func computeBasicStats(st *DashboardStats, orders []*dbent.PaymentOrder, todayStart time.Time) {
var totalAmount, todayAmount float64
var todayCount int
for _, o := range orders {
totalAmount += o.PayAmount
if o.PaidAt != nil && !o.PaidAt.Before(todayStart) {
todayAmount += o.PayAmount
todayCount++
}
}
st.TotalAmount = math.Round(totalAmount*100) / 100
st.TodayAmount = math.Round(todayAmount*100) / 100
st.TotalCount = len(orders)
st.TodayCount = todayCount
if st.TotalCount > 0 {
st.AvgAmount = math.Round(totalAmount/float64(st.TotalCount)*100) / 100
}
}
func buildDailySeries(orders []*dbent.PaymentOrder, since time.Time, days int) []DailyStats {
dailyMap := make(map[string]*DailyStats)
for _, o := range orders {
if o.PaidAt == nil {
continue
}
date := o.PaidAt.Format("2006-01-02")
ds, ok := dailyMap[date]
if !ok {
ds = &DailyStats{Date: date}
dailyMap[date] = ds
}
ds.Amount += o.PayAmount
ds.Count++
}
series := make([]DailyStats, 0, days)
for i := 0; i < days; i++ {
date := since.AddDate(0, 0, i+1).Format("2006-01-02")
if ds, ok := dailyMap[date]; ok {
ds.Amount = math.Round(ds.Amount*100) / 100
series = append(series, *ds)
} else {
series = append(series, DailyStats{Date: date})
}
}
return series
}
func buildMethodDistribution(orders []*dbent.PaymentOrder) []PaymentMethodStat {
methodMap := make(map[string]*PaymentMethodStat)
for _, o := range orders {
ms, ok := methodMap[o.PaymentType]
if !ok {
ms = &PaymentMethodStat{Type: o.PaymentType}
methodMap[o.PaymentType] = ms
}
ms.Amount += o.PayAmount
ms.Count++
}
methods := make([]PaymentMethodStat, 0, len(methodMap))
for _, ms := range methodMap {
ms.Amount = math.Round(ms.Amount*100) / 100
methods = append(methods, *ms)
}
return methods
}
func buildTopUsers(orders []*dbent.PaymentOrder) []TopUserStat {
userMap := make(map[int64]*TopUserStat)
for _, o := range orders {
us, ok := userMap[o.UserID]
if !ok {
us = &TopUserStat{UserID: o.UserID, Email: o.UserEmail}
userMap[o.UserID] = us
}
us.Amount += o.PayAmount
}
userList := make([]*TopUserStat, 0, len(userMap))
for _, us := range userMap {
us.Amount = math.Round(us.Amount*100) / 100
userList = append(userList, us)
}
sort.Slice(userList, func(i, j int) bool {
return userList[i].Amount > userList[j].Amount
})
limit := topUsersLimit
if len(userList) < limit {
limit = len(userList)
}
result := make([]TopUserStat, 0, limit)
for i := 0; i < limit; i++ {
result = append(result, *userList[i])
}
return result
}
// --- Audit Logs ---
func (s *PaymentService) writeAuditLog(ctx context.Context, oid int64, action, op string, detail map[string]any) {
dj, _ := json.Marshal(detail)
_, err := s.entClient.PaymentAuditLog.Create().SetOrderID(strconv.FormatInt(oid, 10)).SetAction(action).SetDetail(string(dj)).SetOperator(op).Save(ctx)
if err != nil {
slog.Error("audit log failed", "orderID", oid, "action", action, "error", err)
}
}
func (s *PaymentService) GetOrderAuditLogs(ctx context.Context, oid int64) ([]*dbent.PaymentAuditLog, error) {
return s.entClient.PaymentAuditLog.Query().Where(paymentauditlog.OrderIDEQ(strconv.FormatInt(oid, 10))).Order(paymentauditlog.ByCreatedAt()).All(ctx)
}
......@@ -167,6 +167,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyBackendModeEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingPaymentEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -227,6 +228,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
}, nil
}
......@@ -276,6 +278,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
BackendModeEnabled bool `json:"backend_mode_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
......@@ -303,6 +306,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
BackendModeEnabled: settings.BackendModeEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
}, nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment