Commit 63d1860d authored by erio's avatar erio
Browse files

feat(payment): add complete payment system with multi-provider support

Add a full payment and subscription system supporting EasyPay (Alipay/WeChat),
Stripe, and direct Alipay/WeChat Pay providers with multi-instance load balancing.
parent 00c08c57
// Package payment provides the core payment provider abstraction,
// registry, load balancing, and shared utilities for the payment subsystem.
package payment
import "context"
// PaymentType represents a supported payment method.
type PaymentType = string
// Supported payment type constants.
const (
TypeAlipay PaymentType = "alipay"
TypeWxpay PaymentType = "wxpay"
TypeAlipayDirect PaymentType = "alipay_direct"
TypeWxpayDirect PaymentType = "wxpay_direct"
TypeStripe PaymentType = "stripe"
TypeCard PaymentType = "card"
TypeLink PaymentType = "link"
TypeEasyPay PaymentType = "easypay"
)
// Order status constants shared across payment and service layers.
const (
OrderStatusPending = "PENDING"
OrderStatusPaid = "PAID"
OrderStatusRecharging = "RECHARGING"
OrderStatusCompleted = "COMPLETED"
OrderStatusExpired = "EXPIRED"
OrderStatusCancelled = "CANCELLED"
OrderStatusFailed = "FAILED"
OrderStatusRefundRequested = "REFUND_REQUESTED"
OrderStatusRefunding = "REFUNDING"
OrderStatusPartiallyRefunded = "PARTIALLY_REFUNDED"
OrderStatusRefunded = "REFUNDED"
OrderStatusRefundFailed = "REFUND_FAILED"
)
// Order types distinguish balance recharges from subscription purchases.
const (
OrderTypeBalance = "balance"
OrderTypeSubscription = "subscription"
)
// Entity statuses shared across users, groups, etc.
const (
EntityStatusActive = "active"
)
// Deduction types for refund flow.
const (
DeductionTypeBalance = "balance"
DeductionTypeSubscription = "subscription"
DeductionTypeNone = "none"
)
// Payment notification status values.
const (
NotificationStatusSuccess = "success"
NotificationStatusPaid = "paid"
)
// Provider-level status constants returned by provider implementations
// to the service layer (lowercase, distinct from OrderStatus uppercase constants).
const (
ProviderStatusPending = "pending"
ProviderStatusPaid = "paid"
ProviderStatusSuccess = "success"
ProviderStatusFailed = "failed"
ProviderStatusRefunded = "refunded"
)
// DefaultLoadBalanceStrategy is the default load-balancing strategy
// used when no strategy is configured.
const DefaultLoadBalanceStrategy = "round-robin"
// ConfigKeyPublishableKey is the config map key for Stripe's publishable key.
const ConfigKeyPublishableKey = "publishableKey"
// GetBasePaymentType extracts the base payment method from a composite key.
// For example, "alipay_direct" -> "alipay".
func GetBasePaymentType(t string) string {
switch {
case t == TypeEasyPay:
return TypeEasyPay
case t == TypeStripe || t == TypeCard || t == TypeLink:
return TypeStripe
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
return TypeAlipay
case len(t) >= len(TypeWxpay) && t[:len(TypeWxpay)] == TypeWxpay:
return TypeWxpay
default:
return t
}
}
// CreatePaymentRequest holds the parameters for creating a new payment.
type CreatePaymentRequest struct {
OrderID string // Internal order ID
Amount string // Pay amount in CNY (formatted to 2 decimal places)
PaymentType string // e.g. "alipay", "wxpay", "stripe"
Subject string // Product description
NotifyURL string // Webhook callback URL
ReturnURL string // Browser redirect URL after payment
ClientIP string // Payer's IP address
IsMobile bool // Whether the request comes from a mobile device
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
}
// CreatePaymentResponse is returned after successfully initiating a payment.
type CreatePaymentResponse struct {
TradeNo string // Third-party transaction ID
PayURL string // H5 payment URL (alipay/wxpay)
QRCode string // QR code content for scanning
ClientSecret string // Stripe PaymentIntent client secret
}
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
TradeNo string
Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY
PaidAt string // RFC3339 timestamp or empty
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
TradeNo string
OrderID string
Amount float64
Status string // "success" or "failed"
RawData string // Raw notification body for audit
}
// RefundRequest contains the parameters for requesting a refund.
type RefundRequest struct {
TradeNo string
OrderID string
Amount string // Refund amount formatted to 2 decimal places
Reason string
}
// RefundResponse is returned after a refund request.
type RefundResponse struct {
RefundID string
Status string // "success", "pending", "failed"
}
// InstanceSelection holds the selected provider instance and its decrypted config.
type InstanceSelection struct {
InstanceID string
Config map[string]string
SupportedTypes string // Comma-separated list of supported payment types from the instance
PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
}
// Provider defines the interface that all payment providers must implement.
type Provider interface {
// Name returns a human-readable name for this provider.
Name() string
// ProviderKey returns the unique key identifying this provider type (e.g. "easypay").
ProviderKey() string
// SupportedTypes returns the list of payment types this provider handles.
SupportedTypes() []PaymentType
// CreatePayment initiates a payment and returns the upstream response.
CreatePayment(ctx context.Context, req CreatePaymentRequest) (*CreatePaymentResponse, error)
// QueryOrder queries the payment status of the given trade number.
QueryOrder(ctx context.Context, tradeNo string) (*QueryOrderResponse, error)
// VerifyNotification parses and verifies a webhook callback.
// Returns nil for unrecognized or irrelevant events (caller should return 200).
VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*PaymentNotification, error)
// Refund requests a refund from the upstream provider.
Refund(ctx context.Context, req RefundRequest) (*RefundResponse, error)
}
// CancelableProvider extends Provider with the ability to cancel pending payments.
type CancelableProvider interface {
Provider
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
package payment
import (
"encoding/hex"
"fmt"
"log/slog"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
)
// EncryptionKey is a named type for the payment encryption key (AES-256, 32 bytes).
// Using a named type avoids Wire ambiguity with other []byte parameters.
type EncryptionKey []byte
// ProvideEncryptionKey derives the payment encryption key from the TOTP encryption key in config.
// When the key is empty, nil is returned (payment features that need encryption will be disabled).
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
if cfg.Totp.EncryptionKey == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("payment encryption key must be 32 bytes, got %d", len(key))
}
return EncryptionKey(key), nil
}
// ProvideRegistry creates an empty payment provider registry.
// Providers are registered at runtime after application startup.
func ProvideRegistry() *Registry {
return NewRegistry()
}
// ProvideDefaultLoadBalancer creates a DefaultLoadBalancer backed by the ent client.
func ProvideDefaultLoadBalancer(client *dbent.Client, key EncryptionKey) *DefaultLoadBalancer {
return NewDefaultLoadBalancer(client, []byte(key))
}
// ProviderSet is the Wire provider set for the payment package.
var ProviderSet = wire.NewSet(
ProvideEncryptionKey,
ProvideRegistry,
ProvideDefaultLoadBalancer,
wire.Bind(new(LoadBalancer), new(*DefaultLoadBalancer)),
)
...@@ -583,6 +583,24 @@ func TestAPIContracts(t *testing.T) { ...@@ -583,6 +583,24 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false, "enable_cch_signing": false,
"enable_fingerprint_unification": true, "enable_fingerprint_unification": true,
"enable_metadata_passthrough": false, "enable_metadata_passthrough": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [], "custom_menu_items": [],
"custom_endpoints": [] "custom_endpoints": []
} }
...@@ -696,7 +714,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -696,7 +714,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
......
...@@ -111,4 +111,5 @@ func registerRoutes( ...@@ -111,4 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
} }
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterPaymentRoutes registers all payment-related routes:
// user-facing endpoints, webhook endpoints, and admin endpoints.
func RegisterPaymentRoutes(
v1 *gin.RouterGroup,
paymentHandler *handler.PaymentHandler,
webhookHandler *handler.PaymentWebhookHandler,
adminPaymentHandler *admin.PaymentHandler,
jwtAuth middleware.JWTAuthMiddleware,
adminAuth middleware.AdminAuthMiddleware,
settingService *service.SettingService,
) {
// --- User-facing payment endpoints (authenticated) ---
authenticated := v1.Group("/payment")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
authenticated.GET("/plans", paymentHandler.GetPlans)
authenticated.GET("/channels", paymentHandler.GetChannels)
authenticated.GET("/limits", paymentHandler.GetLimits)
orders := authenticated.Group("/orders")
{
orders.POST("", paymentHandler.CreateOrder)
orders.POST("/verify", paymentHandler.VerifyOrder)
orders.GET("/my", paymentHandler.GetMyOrders)
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
}
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
}
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
// EasyPay sends GET callbacks with query params
webhook.GET("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
webhook.POST("/stripe", webhookHandler.StripeWebhook)
}
// --- Admin payment endpoints (admin auth) ---
adminGroup := v1.Group("/admin/payment")
adminGroup.Use(gin.HandlerFunc(adminAuth))
{
// Dashboard
adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
// Config
adminGroup.GET("/config", adminPaymentHandler.GetConfig)
adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
// Orders
adminOrders := adminGroup.Group("/orders")
{
adminOrders.GET("", adminPaymentHandler.ListOrders)
adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
}
// Subscription Plans
plans := adminGroup.Group("/plans")
{
plans.GET("", adminPaymentHandler.ListPlans)
plans.POST("", adminPaymentHandler.CreatePlan)
plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
}
// Provider Instances
providers := adminGroup.Group("/providers")
{
providers.GET("", adminPaymentHandler.ListProviders)
providers.POST("", adminPaymentHandler.CreateProvider)
providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
}
}
}
package service
import (
"context"
"encoding/json"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// GetAvailableMethodLimits collects all payment types from enabled provider
// instances and returns limits for each, plus the global widest range.
// Stripe sub-types (card, link) are aggregated under "stripe".
func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*MethodLimitsResponse, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
if err != nil {
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
for pt, insts := range typeInstances {
ml := pcAggregateMethodLimits(pt, insts)
resp.Methods[ml.PaymentType] = ml
}
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
return resp, nil
}
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
if err != nil {
return nil, fmt.Errorf("query provider instances: %w", err)
}
result := make([]MethodLimits, 0, len(types))
for _, pt := range types {
var matching []*dbent.PaymentProviderInstance
for _, inst := range instances {
if payment.InstanceSupportsType(inst.SupportedTypes, pt) {
matching = append(matching, inst)
}
}
result = append(result, pcAggregateMethodLimits(pt, matching))
}
return result, nil
}
// pcGroupByPaymentType groups instances by user-facing payment type.
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
// because the user sees a single "Stripe" button, not individual sub-methods.
// Uses a seen set to avoid counting one instance twice.
func pcGroupByPaymentType(instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
typeInstances := make(map[string][]*dbent.PaymentProviderInstance)
seen := make(map[string]map[int64]bool)
add := func(key string, inst *dbent.PaymentProviderInstance) {
if seen[key] == nil {
seen[key] = make(map[int64]bool)
}
if !seen[key][int64(inst.ID)] {
seen[key][int64(inst.ID)] = true
typeInstances[key] = append(typeInstances[key], inst)
}
}
for _, inst := range instances {
// Stripe provider: all sub-types → single "stripe" group
if inst.ProviderKey == payment.TypeStripe {
add(payment.TypeStripe, inst)
continue
}
for _, t := range splitTypes(inst.SupportedTypes) {
add(t, inst)
}
}
return typeInstances
}
// pcInstanceTypeLimits extracts per-type limits from a provider instance.
// Returns (limits, true) if configured; (zero, false) if unlimited.
// For Stripe instances, limits are stored under "stripe" key regardless of sub-types.
func pcInstanceTypeLimits(inst *dbent.PaymentProviderInstance, pt string) (payment.ChannelLimits, bool) {
if inst.Limits == "" {
return payment.ChannelLimits{}, false
}
var limits payment.InstanceLimits
if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
return payment.ChannelLimits{}, false
}
cl, ok := limits[pt]
return cl, ok
}
// unionFloat merges a single limit value into the aggregate using UNION semantics.
// - For "min" fields (wantMin=true): keeps the lowest non-zero value
// - For "max"/"cap" fields (wantMin=false): keeps the highest non-zero value
// - If any value is 0 (unlimited), the result is unlimited.
//
// Returns (aggregated value, still limited).
func unionFloat(agg float64, limited bool, val float64, wantMin bool) (float64, bool) {
if val == 0 {
return agg, false
}
if !limited {
return agg, false
}
if agg == 0 {
return val, true
}
if wantMin && val < agg {
return val, true
}
if !wantMin && val > agg {
return val, true
}
return agg, true
}
// pcAggregateMethodLimits computes the UNION (least restrictive) of limits
// across all provider instances for a given payment type.
//
// Since the load balancer can route an order to any available instance,
// the user should see the widest possible range:
// - SingleMin: lowest floor across instances; 0 if any is unlimited
// - SingleMax: highest ceiling across instances; 0 if any is unlimited
// - DailyLimit: highest cap across instances; 0 if any is unlimited
func pcAggregateMethodLimits(pt string, instances []*dbent.PaymentProviderInstance) MethodLimits {
ml := MethodLimits{PaymentType: pt}
minLimited, maxLimited, dailyLimited := true, true, true
for _, inst := range instances {
cl, hasLimits := pcInstanceTypeLimits(inst, pt)
if !hasLimits {
return MethodLimits{PaymentType: pt} // any unlimited instance → all zeros
}
ml.SingleMin, minLimited = unionFloat(ml.SingleMin, minLimited, cl.SingleMin, true)
ml.SingleMax, maxLimited = unionFloat(ml.SingleMax, maxLimited, cl.SingleMax, false)
ml.DailyLimit, dailyLimited = unionFloat(ml.DailyLimit, dailyLimited, cl.DailyLimit, false)
}
if !minLimited {
ml.SingleMin = 0
}
if !maxLimited {
ml.SingleMax = 0
}
if !dailyLimited {
ml.DailyLimit = 0
}
return ml
}
// pcComputeGlobalRange computes the widest [min, max] across all methods.
// Uses the same union logic: lowest min, highest max, 0 if any is unlimited.
func pcComputeGlobalRange(methods map[string]MethodLimits) (globalMin, globalMax float64) {
minLimited, maxLimited := true, true
for _, ml := range methods {
globalMin, minLimited = unionFloat(globalMin, minLimited, ml.SingleMin, true)
globalMax, maxLimited = unionFloat(globalMax, maxLimited, ml.SingleMax, false)
}
if !minLimited {
globalMin = 0
}
if !maxLimited {
globalMax = 0
}
return globalMin, globalMax
}
package service
import (
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestUnionFloat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
agg float64
limited bool
val float64
wantMin bool
wantAgg float64
wantLimited bool
}{
{"first non-zero value", 0, true, 5, true, 5, true},
{"lower min replaces", 10, true, 3, true, 3, true},
{"higher min does not replace", 3, true, 10, true, 3, true},
{"higher max replaces", 10, true, 20, false, 20, true},
{"lower max does not replace", 20, true, 10, false, 20, true},
{"zero value makes unlimited", 5, true, 0, true, 5, false},
{"already unlimited stays unlimited", 5, false, 10, true, 5, false},
{"zero on first call", 0, true, 0, true, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotAgg, gotLimited := unionFloat(tt.agg, tt.limited, tt.val, tt.wantMin)
if gotAgg != tt.wantAgg || gotLimited != tt.wantLimited {
t.Fatalf("unionFloat(%v, %v, %v, %v) = (%v, %v), want (%v, %v)",
tt.agg, tt.limited, tt.val, tt.wantMin,
gotAgg, gotLimited, tt.wantAgg, tt.wantLimited)
}
})
}
}
func makeInstance(id int64, providerKey, supportedTypes, limits string) *dbent.PaymentProviderInstance {
return &dbent.PaymentProviderInstance{
ID: id,
ProviderKey: providerKey,
SupportedTypes: supportedTypes,
Limits: limits,
Enabled: true,
}
}
func TestPcAggregateMethodLimits(t *testing.T) {
t.Parallel()
t.Run("single instance with limits", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":2,"singleMax":14},"wxpay":{"singleMin":1,"singleMax":12}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 2 || ml.SingleMax != 14 {
t.Fatalf("alipay limits = min:%v max:%v, want min:2 max:14", ml.SingleMin, ml.SingleMax)
}
})
t.Run("two instances union takes widest range", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":5,"singleMax":100}}`)
inst2 := makeInstance(2, "easypay", "alipay,wxpay",
`{"alipay":{"singleMin":2,"singleMax":200}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 2 {
t.Fatalf("SingleMin = %v, want 2 (lowest floor)", ml.SingleMin)
}
if ml.SingleMax != 200 {
t.Fatalf("SingleMax = %v, want 200 (highest ceiling)", ml.SingleMax)
}
})
t.Run("one instance unlimited makes aggregate unlimited", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "wxpay",
`{"wxpay":{"singleMin":3,"singleMax":10}}`)
inst2 := makeInstance(2, "easypay", "wxpay", "") // no limits = unlimited
ml := pcAggregateMethodLimits("wxpay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("limits = min:%v max:%v, want min:0 max:0 (unlimited)", ml.SingleMin, ml.SingleMax)
}
})
t.Run("one field unlimited others limited", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":5,"singleMax":100}}`)
inst2 := makeInstance(2, "easypay", "alipay",
`{"alipay":{"singleMin":3,"singleMax":0}}`) // singleMax=0 = unlimited
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.SingleMin != 3 {
t.Fatalf("SingleMin = %v, want 3 (lowest floor)", ml.SingleMin)
}
if ml.SingleMax != 0 {
t.Fatalf("SingleMax = %v, want 0 (unlimited)", ml.SingleMax)
}
})
t.Run("empty instances returns zeros", func(t *testing.T) {
t.Parallel()
ml := pcAggregateMethodLimits("alipay", nil)
if ml.SingleMin != 0 || ml.SingleMax != 0 || ml.DailyLimit != 0 {
t.Fatalf("empty instances should return all zeros, got %+v", ml)
}
})
t.Run("invalid JSON treated as unlimited", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", `{invalid json}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("invalid JSON should be treated as unlimited, got %+v", ml)
}
})
t.Run("type not in limits JSON treated as unlimited", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay,wxpay",
`{"wxpay":{"singleMin":1,"singleMax":10}}`) // only wxpay, no alipay
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
if ml.SingleMin != 0 || ml.SingleMax != 0 {
t.Fatalf("missing type should be treated as unlimited, got %+v", ml)
}
})
t.Run("daily limit aggregation", func(t *testing.T) {
t.Parallel()
inst1 := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":1,"singleMax":100,"dailyLimit":500}}`)
inst2 := makeInstance(2, "easypay", "alipay",
`{"alipay":{"singleMin":2,"singleMax":200,"dailyLimit":1000}}`)
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
if ml.DailyLimit != 1000 {
t.Fatalf("DailyLimit = %v, want 1000 (highest cap)", ml.DailyLimit)
}
})
}
func TestPcGroupByPaymentType(t *testing.T) {
t.Parallel()
t.Run("stripe instance maps all types to stripe group", func(t *testing.T) {
t.Parallel()
stripe := makeInstance(1, payment.TypeStripe, "card,alipay,link,wxpay", "")
easypay := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe, easypay})
// Stripe instance should only be in "stripe" group
if len(groups[payment.TypeStripe]) != 1 || groups[payment.TypeStripe][0].ID != 1 {
t.Fatalf("stripe group should contain only stripe instance, got %v", groups[payment.TypeStripe])
}
// alipay group should only contain easypay, NOT stripe
if len(groups[payment.TypeAlipay]) != 1 || groups[payment.TypeAlipay][0].ID != 2 {
t.Fatalf("alipay group should contain only easypay instance, got %v", groups[payment.TypeAlipay])
}
// wxpay group should only contain easypay, NOT stripe
if len(groups[payment.TypeWxpay]) != 1 || groups[payment.TypeWxpay][0].ID != 2 {
t.Fatalf("wxpay group should contain only easypay instance, got %v", groups[payment.TypeWxpay])
}
})
t.Run("multiple easypay instances in same groups", func(t *testing.T) {
t.Parallel()
ep1 := makeInstance(1, payment.TypeEasyPay, "alipay,wxpay", "")
ep2 := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{ep1, ep2})
if len(groups[payment.TypeAlipay]) != 2 {
t.Fatalf("alipay group should have 2 instances, got %d", len(groups[payment.TypeAlipay]))
}
if len(groups[payment.TypeWxpay]) != 2 {
t.Fatalf("wxpay group should have 2 instances, got %d", len(groups[payment.TypeWxpay]))
}
})
t.Run("stripe with no supported types still in stripe group", func(t *testing.T) {
t.Parallel()
stripe := makeInstance(1, payment.TypeStripe, "", "")
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe})
if len(groups[payment.TypeStripe]) != 1 {
t.Fatalf("stripe with empty types should still be in stripe group, got %v", groups)
}
})
}
func TestPcComputeGlobalRange(t *testing.T) {
t.Parallel()
t.Run("all methods have limits", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 2, SingleMax: 14},
"wxpay": {SingleMin: 1, SingleMax: 12},
"stripe": {SingleMin: 5, SingleMax: 100},
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 1 {
t.Fatalf("global min = %v, want 1 (lowest floor)", gMin)
}
if gMax != 100 {
t.Fatalf("global max = %v, want 100 (highest ceiling)", gMax)
}
})
t.Run("one method unlimited makes global unlimited", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 2, SingleMax: 14},
"stripe": {SingleMin: 0, SingleMax: 0}, // unlimited
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 0 {
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
}
if gMax != 0 {
t.Fatalf("global max = %v, want 0 (unlimited)", gMax)
}
})
t.Run("empty methods returns zeros", func(t *testing.T) {
t.Parallel()
gMin, gMax := pcComputeGlobalRange(map[string]MethodLimits{})
if gMin != 0 || gMax != 0 {
t.Fatalf("empty methods should return (0, 0), got (%v, %v)", gMin, gMax)
}
})
t.Run("only min unlimited", func(t *testing.T) {
t.Parallel()
methods := map[string]MethodLimits{
"alipay": {SingleMin: 0, SingleMax: 100},
"wxpay": {SingleMin: 5, SingleMax: 50},
}
gMin, gMax := pcComputeGlobalRange(methods)
if gMin != 0 {
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
}
if gMax != 100 {
t.Fatalf("global max = %v, want 100", gMax)
}
})
}
func TestPcInstanceTypeLimits(t *testing.T) {
t.Parallel()
t.Run("empty limits string returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", "")
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for empty limits")
}
})
t.Run("type found returns correct values", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay",
`{"alipay":{"singleMin":2,"singleMax":14,"dailyLimit":500}}`)
cl, ok := pcInstanceTypeLimits(inst, "alipay")
if !ok {
t.Fatal("expected ok=true")
}
if cl.SingleMin != 2 || cl.SingleMax != 14 || cl.DailyLimit != 500 {
t.Fatalf("limits = %+v, want min:2 max:14 daily:500", cl)
}
})
t.Run("type not found returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay",
`{"wxpay":{"singleMin":1}}`)
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for missing type")
}
})
t.Run("invalid JSON returns false", func(t *testing.T) {
t.Parallel()
inst := makeInstance(1, "easypay", "alipay", `{bad json}`)
_, ok := pcInstanceTypeLimits(inst, "alipay")
if ok {
t.Fatal("expected ok=false for invalid JSON")
}
})
}
package service
import (
"context"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display.
type PlanGroupInfo struct {
Platform string `json:"platform"`
Name string `json:"name"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
ModelScopes []string `json:"supported_model_scopes"`
}
// GetGroupPlatformMap returns a map of group_id → platform for the given plans.
func (s *PaymentConfigService) GetGroupPlatformMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]string {
info := s.GetGroupInfoMap(ctx, plans)
m := make(map[int64]string, len(info))
for id, gi := range info {
m[id] = gi.Platform
}
return m
}
// GetGroupInfoMap returns a map of group_id → PlanGroupInfo for the given plans.
func (s *PaymentConfigService) GetGroupInfoMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]PlanGroupInfo {
ids := make([]int64, 0, len(plans))
seen := make(map[int64]bool)
for _, p := range plans {
if !seen[p.GroupID] {
seen[p.GroupID] = true
ids = append(ids, p.GroupID)
}
}
if len(ids) == 0 {
return nil
}
groups, err := s.entClient.Group.Query().Where(group.IDIn(ids...)).All(ctx)
if err != nil {
return nil
}
m := make(map[int64]PlanGroupInfo, len(groups))
for _, g := range groups {
m[int64(g.ID)] = PlanGroupInfo{
Platform: g.Platform,
Name: g.Name,
RateMultiplier: g.RateMultiplier,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ModelScopes: g.SupportedModelScopes,
}
}
return m
}
func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx)
}
func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx)
}
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
b := s.entClient.SubscriptionPlan.Create().
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
SetFeatures(req.Features).SetProductName(req.ProductName).
SetForSale(req.ForSale).SetSortOrder(req.SortOrder)
if req.OriginalPrice != nil {
b.SetOriginalPrice(*req.OriginalPrice)
}
return b.Save(ctx)
}
// UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
u := s.entClient.SubscriptionPlan.UpdateOneID(id)
if req.GroupID != nil {
u.SetGroupID(*req.GroupID)
}
if req.Name != nil {
u.SetName(*req.Name)
}
if req.Description != nil {
u.SetDescription(*req.Description)
}
if req.Price != nil {
u.SetPrice(*req.Price)
}
if req.OriginalPrice != nil {
u.SetOriginalPrice(*req.OriginalPrice)
}
if req.ValidityDays != nil {
u.SetValidityDays(*req.ValidityDays)
}
if req.ValidityUnit != nil {
u.SetValidityUnit(*req.ValidityUnit)
}
if req.Features != nil {
u.SetFeatures(*req.Features)
}
if req.ProductName != nil {
u.SetProductName(*req.ProductName)
}
if req.ForSale != nil {
u.SetForSale(*req.ForSale)
}
if req.SortOrder != nil {
u.SetSortOrder(*req.SortOrder)
}
return u.Save(ctx)
}
func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error {
count, err := s.countPendingOrdersByPlan(ctx, id)
if err != nil {
return fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return infraerrors.Conflict("PENDING_ORDERS",
fmt.Sprintf("this plan has %d in-progress orders and cannot be deleted — wait for orders to complete first", count))
}
return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx)
}
// GetPlan returns a subscription plan by ID.
func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) {
plan, err := s.entClient.SubscriptionPlan.Get(ctx, id)
if err != nil {
return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found")
}
return plan, nil
}
package service
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Provider Instance CRUD ---
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx)
}
// ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct {
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Context) ([]ProviderInstanceResponse, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Order(paymentproviderinstance.BySortOrder()).All(ctx)
if err != nil {
return nil, err
}
result := make([]ProviderInstanceResponse, 0, len(instances))
for _, inst := range instances {
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder,
PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
result = append(result, resp)
}
return result, nil
}
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
return s.decryptConfig(encrypted)
}
// pendingOrderStatuses are order statuses considered "in progress".
var pendingOrderStatuses = []string{
payment.OrderStatusPending,
payment.OrderStatusPaid,
payment.OrderStatusRecharging,
}
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
func isSensitiveConfigField(fieldName string) bool {
lower := strings.ToLower(fieldName)
for _, p := range sensitiveConfigPatterns {
if strings.Contains(lower, p) {
return true
}
}
return false
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
paymentorder.ProviderInstanceIDEQ(strconv.FormatInt(providerInstanceID, 10)),
paymentorder.StatusIn(pendingOrderStatuses...),
).Count(ctx)
}
func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, planID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
paymentorder.PlanIDEQ(planID),
paymentorder.StatusIn(pendingOrderStatuses...),
).Count(ctx)
}
var validProviderKeys = map[string]bool{
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
}
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
typesStr := joinTypes(req.SupportedTypes)
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
return nil, err
}
enc, err := s.encryptConfig(req.Config)
if err != nil {
return nil, err
}
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
Save(ctx)
}
func validateProviderRequest(providerKey, name, supportedTypes string) error {
if strings.TrimSpace(name) == "" {
return infraerrors.BadRequest("VALIDATION_ERROR", "provider name is required")
}
if !validProviderKeys[providerKey] {
return infraerrors.BadRequest("VALIDATION_ERROR", fmt.Sprintf("invalid provider key: %s", providerKey))
}
// supported_types can be empty (provider accepts no payment types until configured)
return nil
}
// UpdateProviderInstance updates a provider instance by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
if req.Config != nil {
hasSensitive := false
for k := range req.Config {
if isSensitiveConfigField(k) && req.Config[k] != "" {
hasSensitive = true
break
}
}
if hasSensitive {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
}
if req.Enabled != nil && !*req.Enabled {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
if req.Name != nil {
u.SetName(*req.Name)
}
if req.Config != nil {
merged, err := s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
enc, err := s.encryptConfig(merged)
if err != nil {
return nil, err
}
u.SetConfig(enc)
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
// Load current instance to compare types
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
oldTypes := strings.Split(inst.SupportedTypes, ",")
newTypes := req.SupportedTypes
for _, ot := range oldTypes {
ot = strings.TrimSpace(ot)
if ot == "" {
continue
}
found := false
for _, nt := range newTypes {
if strings.TrimSpace(nt) == ot {
found = true
break
}
}
if !found {
return nil, infraerrors.Conflict("PENDING_ORDERS", "cannot remove payment types while instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
}
u.SetSupportedTypes(joinTypes(req.SupportedTypes))
}
if req.Enabled != nil {
u.SetEnabled(*req.Enabled)
}
if req.SortOrder != nil {
u.SetSortOrder(*req.SortOrder)
}
if req.Limits != nil {
u.SetLimits(*req.Limits)
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
}
return u.Save(ctx)
}
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load existing provider: %w", err)
}
existing, err := s.decryptConfig(inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
return newConfig, nil
}
for k, v := range newConfig {
existing[k] = v
}
return existing, nil
}
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
if encrypted == "" {
return nil, nil
}
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
if err != nil {
return nil, fmt.Errorf("decrypt config: %w", err)
}
var raw map[string]string
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
}
return raw, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return fmt.Errorf("check pending orders: %w", err)
}
if count > 0 {
return infraerrors.Conflict("PENDING_ORDERS",
fmt.Sprintf("this instance has %d in-progress orders and cannot be deleted — wait for orders to complete or disable the instance first", count))
}
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
enc, err := payment.Encrypt(string(data), s.encryptionKey)
if err != nil {
return "", fmt.Errorf("encrypt config: %w", err)
}
return enc, nil
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateProviderRequest(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
providerName string
supportedTypes string
wantErr bool
errContains string
}{
{
name: "valid easypay with types",
providerKey: "easypay",
providerName: "MyProvider",
supportedTypes: "alipay,wxpay",
wantErr: false,
},
{
name: "valid stripe with empty types",
providerKey: "stripe",
providerName: "Stripe Provider",
supportedTypes: "",
wantErr: false,
},
{
name: "valid alipay provider",
providerKey: "alipay",
providerName: "Alipay Direct",
supportedTypes: "alipay",
wantErr: false,
},
{
name: "valid wxpay provider",
providerKey: "wxpay",
providerName: "WeChat Pay",
supportedTypes: "wxpay",
wantErr: false,
},
{
name: "invalid provider key",
providerKey: "invalid",
providerName: "Name",
supportedTypes: "alipay",
wantErr: true,
errContains: "invalid provider key",
},
{
name: "empty name",
providerKey: "easypay",
providerName: "",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
{
name: "whitespace-only name",
providerKey: "easypay",
providerName: " ",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
{
name: "tab-only name",
providerKey: "easypay",
providerName: "\t",
supportedTypes: "alipay",
wantErr: true,
errContains: "provider name is required",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := validateProviderRequest(tc.providerKey, tc.providerName, tc.supportedTypes)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
require.NoError(t, err)
}
})
}
}
func TestIsSensitiveConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
field string
wantSen bool
}{
// Sensitive fields (contain key/secret/private/password/pkey patterns)
{"secretKey", true},
{"apiSecret", true},
{"pkey", true},
{"privateKey", true},
{"apiPassword", true},
{"appKey", true},
{"SECRET_TOKEN", true},
{"PrivateData", true},
{"PASSWORD", true},
{"mySecretValue", true},
// Non-sensitive fields
{"appId", false},
{"mchId", false},
{"apiBase", false},
{"endpoint", false},
{"merchantNo", false},
{"paymentMode", false},
{"notifyUrl", false},
}
for _, tc := range tests {
t.Run(tc.field, func(t *testing.T) {
t.Parallel()
got := isSensitiveConfigField(tc.field)
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
})
}
}
func TestJoinTypes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []string
want string
}{
{
name: "multiple types",
input: []string{"alipay", "wxpay"},
want: "alipay,wxpay",
},
{
name: "single type",
input: []string{"stripe"},
want: "stripe",
},
{
name: "empty slice",
input: []string{},
want: "",
},
{
name: "nil slice",
input: nil,
want: "",
},
{
name: "three types",
input: []string{"alipay", "wxpay", "stripe"},
want: "alipay,wxpay,stripe",
},
{
name: "types with spaces are not trimmed",
input: []string{" alipay ", " wxpay "},
want: " alipay , wxpay ",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := joinTypes(tc.input)
assert.Equal(t, tc.want, got)
})
}
}
package service
import (
"context"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
const (
SettingPaymentEnabled = "payment_enabled"
SettingMinRechargeAmount = "MIN_RECHARGE_AMOUNT"
SettingMaxRechargeAmount = "MAX_RECHARGE_AMOUNT"
SettingDailyRechargeLimit = "DAILY_RECHARGE_LIMIT"
SettingOrderTimeoutMinutes = "ORDER_TIMEOUT_MINUTES"
SettingMaxPendingOrders = "MAX_PENDING_ORDERS"
SettingEnabledPaymentTypes = "ENABLED_PAYMENT_TYPES"
SettingLoadBalanceStrategy = "LOAD_BALANCE_STRATEGY"
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
SettingHelpText = "PAYMENT_HELP_TEXT"
SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED"
SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX"
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
SettingCancelWindowUnit = "CANCEL_RATE_LIMIT_UNIT"
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
)
// Default values for payment configuration settings.
const (
defaultOrderTimeoutMin = 30
defaultMaxPendingOrders = 3
)
// PaymentConfig holds the payment system configuration.
type PaymentConfig struct {
Enabled bool `json:"enabled"`
MinAmount float64 `json:"min_amount"`
MaxAmount float64 `json:"max_amount"`
DailyLimit float64 `json:"daily_limit"`
OrderTimeoutMin int `json:"order_timeout_minutes"`
MaxPendingOrders int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled bool `json:"balance_disabled"`
LoadBalanceStrategy string `json:"load_balance_strategy"`
ProductNamePrefix string `json:"product_name_prefix"`
ProductNameSuffix string `json:"product_name_suffix"`
HelpImageURL string `json:"help_image_url"`
HelpText string `json:"help_text"`
StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
// Cancel rate limit settings
CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
CancelRateLimitMax int `json:"cancel_rate_limit_max"`
CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
}
// UpdatePaymentConfigRequest contains fields to update payment configuration.
type UpdatePaymentConfigRequest struct {
Enabled *bool `json:"enabled"`
MinAmount *float64 `json:"min_amount"`
MaxAmount *float64 `json:"max_amount"`
DailyLimit *float64 `json:"daily_limit"`
OrderTimeoutMin *int `json:"order_timeout_minutes"`
MaxPendingOrders *int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled *bool `json:"balance_disabled"`
LoadBalanceStrategy *string `json:"load_balance_strategy"`
ProductNamePrefix *string `json:"product_name_prefix"`
ProductNameSuffix *string `json:"product_name_suffix"`
HelpImageURL *string `json:"help_image_url"`
HelpText *string `json:"help_text"`
// Cancel rate limit settings
CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
CancelRateLimitMax *int `json:"cancel_rate_limit_max"`
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
}
// MethodLimits holds per-payment-type limits.
type MethodLimits struct {
PaymentType string `json:"payment_type"`
FeeRate float64 `json:"fee_rate"`
DailyLimit float64 `json:"daily_limit"`
SingleMin float64 `json:"single_min"`
SingleMax float64 `json:"single_max"`
}
// MethodLimitsResponse is the full response for the user-facing /limits API.
// It includes per-method limits and the global widest range (union of all methods).
type MethodLimitsResponse struct {
Methods map[string]MethodLimits `json:"methods"`
GlobalMin float64 `json:"global_min"` // 0 = no minimum
GlobalMax float64 `json:"global_max"` // 0 = no maximum
}
type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
}
type CreatePlanRequest struct {
GroupID int64 `json:"group_id"`
Name string `json:"name"`
Description string `json:"description"`
Price float64 `json:"price"`
OriginalPrice *float64 `json:"original_price"`
ValidityDays int `json:"validity_days"`
ValidityUnit string `json:"validity_unit"`
Features string `json:"features"`
ProductName string `json:"product_name"`
ForSale bool `json:"for_sale"`
SortOrder int `json:"sort_order"`
}
type UpdatePlanRequest struct {
GroupID *int64 `json:"group_id"`
Name *string `json:"name"`
Description *string `json:"description"`
Price *float64 `json:"price"`
OriginalPrice *float64 `json:"original_price"`
ValidityDays *int `json:"validity_days"`
ValidityUnit *string `json:"validity_unit"`
Features *string `json:"features"`
ProductName *string `json:"product_name"`
ForSale *bool `json:"for_sale"`
SortOrder *int `json:"sort_order"`
}
// PaymentConfigService manages payment configuration and CRUD for
// provider instances, channels, and subscription plans.
type PaymentConfigService struct {
entClient *dbent.Client
settingRepo SettingRepository
encryptionKey []byte
}
// NewPaymentConfigService creates a new PaymentConfigService.
func NewPaymentConfigService(entClient *dbent.Client, settingRepo SettingRepository, encryptionKey []byte) *PaymentConfigService {
return &PaymentConfigService{entClient: entClient, settingRepo: settingRepo, encryptionKey: encryptionKey}
}
// IsPaymentEnabled returns whether the payment system is enabled.
func (s *PaymentConfigService) IsPaymentEnabled(ctx context.Context) bool {
val, err := s.settingRepo.GetValue(ctx, SettingPaymentEnabled)
if err != nil {
return false
}
return val == "true"
}
// GetPaymentConfig returns the full payment configuration.
func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentConfig, error) {
keys := []string{
SettingPaymentEnabled, SettingMinRechargeAmount, SettingMaxRechargeAmount,
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy,
SettingProductNamePrefix, SettingProductNameSuffix,
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
}
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
cfg := &PaymentConfig{
Enabled: vals[SettingPaymentEnabled] == "true",
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
ProductNamePrefix: vals[SettingProductNamePrefix],
ProductNameSuffix: vals[SettingProductNameSuffix],
HelpImageURL: vals[SettingHelpImageURL],
HelpText: vals[SettingHelpText],
CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
CancelRateLimitUnit: vals[SettingCancelWindowUnit],
CancelRateLimitMode: vals[SettingCancelWindowMode],
}
if cfg.LoadBalanceStrategy == "" {
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
}
}
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe),
).Limit(1).All(ctx)
if err != nil || len(instances) == 0 {
return ""
}
cfg, err := s.decryptConfig(instances[0].Config)
if err != nil || cfg == nil {
return ""
}
return cfg[payment.ConfigKeyPublishableKey]
}
// UpdatePaymentConfig updates the payment configuration settings.
// NOTE: This function exceeds 30 lines because each field requires an independent
// nil-check before serialisation — this is inherent to patch-style update patterns
// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
m := map[string]string{
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
SettingHelpImageURL: derefStr(req.HelpImageURL),
SettingHelpText: derefStr(req.HelpText),
SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
} else {
m[SettingEnabledPaymentTypes] = ""
}
return s.settingRepo.SetMultiple(ctx, m)
}
func formatBoolOrEmpty(v *bool) string {
if v == nil {
return ""
}
return strconv.FormatBool(*v)
}
func formatPositiveFloat(v *float64) string {
if v == nil || *v <= 0 {
return "" // empty → parsePaymentConfig uses default
}
return strconv.FormatFloat(*v, 'f', 2, 64)
}
func formatPositiveInt(v *int) string {
if v == nil || *v <= 0 {
return ""
}
return strconv.Itoa(*v)
}
func derefStr(v *string) string {
if v == nil {
return ""
}
return *v
}
func splitTypes(s string) []string {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
func joinTypes(types []string) string {
return strings.Join(types, ",")
}
func pcParseFloat(s string, defaultVal float64) float64 {
if s == "" {
return defaultVal
}
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return defaultVal
}
return v
}
func pcParseInt(s string, defaultVal int) int {
if s == "" {
return defaultVal
}
v, err := strconv.Atoi(s)
if err != nil {
return defaultVal
}
return v
}
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestPcParseFloat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultVal float64
expected float64
}{
{"empty string returns default", "", 1.0, 1.0},
{"valid float", "3.14", 0, 3.14},
{"valid integer as float", "42", 0, 42.0},
{"invalid string returns default", "notanumber", 9.99, 9.99},
{"zero value", "0", 5.0, 0},
{"negative value", "-10.5", 0, -10.5},
{"very large value", "99999999.99", 0, 99999999.99},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := pcParseFloat(tt.input, tt.defaultVal)
if got != tt.expected {
t.Fatalf("pcParseFloat(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
}
})
}
}
func TestPcParseInt(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultVal int
expected int
}{
{"empty string returns default", "", 30, 30},
{"valid int", "10", 0, 10},
{"invalid string returns default", "abc", 5, 5},
{"float string returns default", "3.14", 0, 0},
{"zero value", "0", 99, 0},
{"negative value", "-1", 0, -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := pcParseInt(tt.input, tt.defaultVal)
if got != tt.expected {
t.Fatalf("pcParseInt(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
}
})
}
}
func TestParsePaymentConfig(t *testing.T) {
t.Parallel()
svc := &PaymentConfigService{}
t.Run("empty vals uses defaults", func(t *testing.T) {
t.Parallel()
cfg := svc.parsePaymentConfig(map[string]string{})
if cfg.Enabled {
t.Fatal("expected Enabled=false by default")
}
if cfg.MinAmount != 1 {
t.Fatalf("expected MinAmount=1, got %v", cfg.MinAmount)
}
if cfg.MaxAmount != 0 {
t.Fatalf("expected MaxAmount=0 (no limit), got %v", cfg.MaxAmount)
}
if cfg.OrderTimeoutMin != 30 {
t.Fatalf("expected OrderTimeoutMin=30, got %v", cfg.OrderTimeoutMin)
}
if cfg.MaxPendingOrders != 3 {
t.Fatalf("expected MaxPendingOrders=3, got %v", cfg.MaxPendingOrders)
}
if cfg.LoadBalanceStrategy != payment.DefaultLoadBalanceStrategy {
t.Fatalf("expected LoadBalanceStrategy=%s, got %q", payment.DefaultLoadBalanceStrategy, cfg.LoadBalanceStrategy)
}
if len(cfg.EnabledTypes) != 0 {
t.Fatalf("expected empty EnabledTypes, got %v", cfg.EnabledTypes)
}
})
t.Run("all values populated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingPaymentEnabled: "true",
SettingMinRechargeAmount: "5.00",
SettingMaxRechargeAmount: "1000.00",
SettingDailyRechargeLimit: "5000.00",
SettingOrderTimeoutMinutes: "15",
SettingMaxPendingOrders: "5",
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
SettingBalancePayDisabled: "true",
SettingLoadBalanceStrategy: "least_amount",
SettingProductNamePrefix: "PRE",
SettingProductNameSuffix: "SUF",
}
cfg := svc.parsePaymentConfig(vals)
if !cfg.Enabled {
t.Fatal("expected Enabled=true")
}
if cfg.MinAmount != 5 {
t.Fatalf("MinAmount = %v, want 5", cfg.MinAmount)
}
if cfg.MaxAmount != 1000 {
t.Fatalf("MaxAmount = %v, want 1000", cfg.MaxAmount)
}
if cfg.DailyLimit != 5000 {
t.Fatalf("DailyLimit = %v, want 5000", cfg.DailyLimit)
}
if cfg.OrderTimeoutMin != 15 {
t.Fatalf("OrderTimeoutMin = %v, want 15", cfg.OrderTimeoutMin)
}
if cfg.MaxPendingOrders != 5 {
t.Fatalf("MaxPendingOrders = %v, want 5", cfg.MaxPendingOrders)
}
if len(cfg.EnabledTypes) != 3 {
t.Fatalf("EnabledTypes len = %d, want 3", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" || cfg.EnabledTypes[2] != "stripe" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay stripe]", cfg.EnabledTypes)
}
if !cfg.BalanceDisabled {
t.Fatal("expected BalanceDisabled=true")
}
if cfg.LoadBalanceStrategy != "least_amount" {
t.Fatalf("LoadBalanceStrategy = %q, want %q", cfg.LoadBalanceStrategy, "least_amount")
}
if cfg.ProductNamePrefix != "PRE" {
t.Fatalf("ProductNamePrefix = %q, want %q", cfg.ProductNamePrefix, "PRE")
}
if cfg.ProductNameSuffix != "SUF" {
t.Fatalf("ProductNameSuffix = %q, want %q", cfg.ProductNameSuffix, "SUF")
}
})
t.Run("enabled types with spaces are trimmed", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: " alipay , wxpay ",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 0 {
t.Fatalf("expected empty EnabledTypes for empty string, got %v", cfg.EnabledTypes)
}
})
}
func TestGetBasePaymentType(t *testing.T) {
t.Parallel()
tests := []struct {
input string
expected string
}{
{payment.TypeEasyPay, payment.TypeEasyPay},
{payment.TypeStripe, payment.TypeStripe},
{payment.TypeCard, payment.TypeStripe},
{payment.TypeLink, payment.TypeStripe},
{payment.TypeAlipay, payment.TypeAlipay},
{payment.TypeAlipayDirect, payment.TypeAlipay},
{payment.TypeWxpay, payment.TypeWxpay},
{payment.TypeWxpayDirect, payment.TypeWxpay},
{"unknown", "unknown"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
got := payment.GetBasePaymentType(tt.input)
if got != tt.expected {
t.Fatalf("GetBasePaymentType(%q) = %q, want %q", tt.input, got, tt.expected)
}
})
}
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Payment Notification & Fulfillment ---
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
if n.Status != payment.NotificationStatusSuccess {
return nil
}
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
return nil
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
grace := now.Add(-paymentGraceMinutes * time.Minute)
c, err := s.entClient.PaymentOrder.Update().Where(
paymentorder.IDEQ(o.ID),
paymentorder.Or(
paymentorder.StatusEQ(OrderStatusPending),
paymentorder.StatusEQ(OrderStatusCancelled),
paymentorder.And(
paymentorder.StatusEQ(OrderStatusExpired),
paymentorder.UpdatedAtGTE(grace),
),
),
).SetStatus(OrderStatusPaid).SetPayAmount(paid).SetPaymentTradeNo(tradeNo).SetPaidAt(now).ClearFailedAt().ClearFailedReason().Save(ctx)
if err != nil {
return fmt.Errorf("update to PAID: %w", err)
}
if c == 0 {
return s.alreadyProcessed(ctx, o)
}
if previousStatus == OrderStatusCancelled || previousStatus == OrderStatusExpired {
slog.Info("order recovered from webhook payment success",
"orderID", o.ID,
"previousStatus", previousStatus,
"tradeNo", tradeNo,
"provider", pk,
)
s.writeAuditLog(ctx, o.ID, "ORDER_RECOVERED", pk, map[string]any{
"previous_status": previousStatus,
"tradeNo": tradeNo,
"paidAmount": paid,
"reason": "webhook payment success received after order " + previousStatus,
})
}
s.writeAuditLog(ctx, o.ID, "ORDER_PAID", pk, map[string]any{"tradeNo": tradeNo, "paidAmount": paid})
return s.executeFulfillment(ctx, o.ID)
}
func (s *PaymentService) alreadyProcessed(ctx context.Context, o *dbent.PaymentOrder) error {
cur, err := s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil
}
switch cur.Status {
case OrderStatusCompleted, OrderStatusRefunded:
return nil
case OrderStatusFailed:
return s.executeFulfillment(ctx, o.ID)
case OrderStatusPaid, OrderStatusRecharging:
return fmt.Errorf("order %d is being processed", o.ID)
case OrderStatusExpired:
slog.Warn("webhook payment success for expired order beyond grace period",
"orderID", o.ID,
"status", cur.Status,
"updatedAt", cur.UpdatedAt,
)
s.writeAuditLog(ctx, o.ID, "PAYMENT_AFTER_EXPIRY", "system", map[string]any{
"status": cur.Status,
"updatedAt": cur.UpdatedAt,
"reason": "payment arrived after expiry grace period",
})
return nil
default:
return nil
}
}
func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return fmt.Errorf("get order: %w", err)
}
if o.OrderType == payment.OrderTypeSubscription {
return s.ExecuteSubscriptionFulfillment(ctx, oid)
}
return s.ExecuteBalanceFulfillment(ctx, oid)
}
func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusCompleted {
return nil
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
}
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
if err != nil {
return fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil
}
if err := s.doBalance(ctx, o); err != nil {
s.markFailed(ctx, oid, err)
return err
}
return nil
}
// redeemAction represents the idempotency decision for balance fulfillment.
type redeemAction int
const (
// redeemActionCreate: code does not exist — create it, then redeem.
redeemActionCreate redeemAction = iota
// redeemActionRedeem: code exists but is unused — skip creation, redeem only.
redeemActionRedeem
// redeemActionSkipCompleted: code exists and is already used — skip to mark completed.
redeemActionSkipCompleted
)
// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup.
// existing is the result of GetByCode; lookupErr is the error from that call.
func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction {
if existing == nil || lookupErr != nil {
return redeemActionCreate
}
if existing.IsUsed() {
return redeemActionSkipCompleted
}
return redeemActionRedeem
}
func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error {
// Idempotency: check if redeem code already exists (from a previous partial run)
existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode)
action := resolveRedeemAction(existing, lookupErr)
switch action {
case redeemActionSkipCompleted:
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused}
if err := s.redeemService.CreateCode(ctx, rc); err != nil {
return fmt.Errorf("create redeem code: %w", err)
}
case redeemActionRedeem:
// Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrder, auditAction string) error {
now := time.Now()
_, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx)
if err != nil {
return fmt.Errorf("mark completed: %w", err)
}
s.writeAuditLog(ctx, o.ID, auditAction, "system", map[string]any{"rechargeCode": o.RechargeCode, "amount": o.Amount})
return nil
}
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusCompleted {
return nil
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
}
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
}
if o.SubscriptionGroupID == nil || o.SubscriptionDays == nil {
return infraerrors.BadRequest("INVALID_STATUS", "missing subscription info")
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
if err != nil {
return fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil
}
if err := s.doSub(ctx, o); err != nil {
s.markFailed(ctx, oid, err)
return err
}
return nil
}
func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error {
gid := *o.SubscriptionGroupID
days := *o.SubscriptionDays
g, err := s.groupRepo.GetByID(ctx, gid)
if err != nil || g.Status != payment.EntityStatusActive {
return fmt.Errorf("group %d no longer exists or inactive", gid)
}
// Idempotency: check audit log to see if subscription was already assigned.
// Prevents double-extension on retry after markCompleted fails.
if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") {
slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid)
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
}
orderNote := fmt.Sprintf("payment order %d", o.ID)
_, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote})
if err != nil {
return fmt.Errorf("assign subscription: %w", err)
}
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
}
func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool {
oid := strconv.FormatInt(orderID, 10)
c, _ := s.entClient.PaymentAuditLog.Query().
Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)).
Limit(1).Count(ctx)
return c > 0
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
// Only mark FAILED if still in RECHARGING state — prevents overwriting
// a COMPLETED order when markCompleted failed but fulfillment succeeded.
c, e := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)).
SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
if e != nil {
slog.Error("mark FAILED", "orderID", oid, "error", e)
}
if c > 0 {
s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
}
}
func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.PaidAt == nil {
return infraerrors.BadRequest("INVALID_STATUS", "order is not paid")
}
if psIsRefundStatus(o.Status) {
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot retry")
}
if o.Status == OrderStatusRecharging {
return infraerrors.Conflict("CONFLICT", "order is being processed")
}
if o.Status == OrderStatusCompleted {
return infraerrors.BadRequest("INVALID_STATUS", "order already completed")
}
if o.Status != OrderStatusFailed && o.Status != OrderStatusPaid {
return infraerrors.BadRequest("INVALID_STATUS", "only paid and failed orders can retry")
}
_, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusFailed, OrderStatusPaid)).SetStatus(OrderStatusPaid).ClearFailedAt().ClearFailedReason().Save(ctx)
if err != nil {
return fmt.Errorf("reset for retry: %w", err)
}
s.writeAuditLog(ctx, oid, "RECHARGE_RETRY", "admin", map[string]any{"detail": "admin manual retry"})
return s.executeFulfillment(ctx, oid)
}
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_CodeNotFound(t *testing.T) {
t.Parallel()
action := resolveRedeemAction(nil, nil)
assert.Equal(t, redeemActionCreate, action, "nil code with nil error should create")
}
func TestResolveRedeemAction_LookupError(t *testing.T) {
t.Parallel()
action := resolveRedeemAction(nil, errors.New("db connection lost"))
assert.Equal(t, redeemActionCreate, action, "lookup error should fall back to create")
}
func TestResolveRedeemAction_LookupErrorWithNonNilCode(t *testing.T) {
t.Parallel()
// Edge case: both code and error are non-nil (shouldn't happen in practice,
// but the function should still treat error as authoritative)
code := &RedeemCode{Status: StatusUnused}
action := resolveRedeemAction(code, errors.New("partial error"))
assert.Equal(t, redeemActionCreate, action, "non-nil error should always result in create regardless of code")
}
func TestResolveRedeemAction_CodeExistsAndUsed(t *testing.T) {
t.Parallel()
code := &RedeemCode{
Code: "test-code-123",
Status: StatusUsed,
Type: RedeemTypeBalance,
Value: 10.0,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionSkipCompleted, action, "used code should skip to completed")
}
func TestResolveRedeemAction_CodeExistsAndUnused(t *testing.T) {
t.Parallel()
code := &RedeemCode{
Code: "test-code-456",
Status: StatusUnused,
Type: RedeemTypeBalance,
Value: 25.0,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionRedeem, action, "unused code should skip creation and proceed to redeem")
}
func TestResolveRedeemAction_CodeExistsWithExpiredStatus(t *testing.T) {
t.Parallel()
// A code with a non-standard status (neither "unused" nor "used")
// should NOT be treated as used, so it falls through to redeemActionRedeem.
code := &RedeemCode{
Code: "expired-code",
Status: StatusExpired,
}
action := resolveRedeemAction(code, nil)
assert.Equal(t, redeemActionRedeem, action, "expired-status code is not IsUsed(), should redeem")
}
// ---------------------------------------------------------------------------
// Table-driven comprehensive test
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_Table(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code *RedeemCode
err error
expected redeemAction
}{
{
name: "nil code, nil error — first run",
code: nil,
err: nil,
expected: redeemActionCreate,
},
{
name: "nil code, lookup error — treat as not found",
code: nil,
err: ErrRedeemCodeNotFound,
expected: redeemActionCreate,
},
{
name: "nil code, generic DB error — treat as not found",
code: nil,
err: errors.New("connection refused"),
expected: redeemActionCreate,
},
{
name: "code exists, used — previous run completed redeem",
code: &RedeemCode{Status: StatusUsed},
err: nil,
expected: redeemActionSkipCompleted,
},
{
name: "code exists, unused — previous run created code but crashed before redeem",
code: &RedeemCode{Status: StatusUnused},
err: nil,
expected: redeemActionRedeem,
},
{
name: "code exists but error also set — error takes precedence",
code: &RedeemCode{Status: StatusUsed},
err: errors.New("unexpected"),
expected: redeemActionCreate,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := resolveRedeemAction(tt.code, tt.err)
assert.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// redeemAction enum value sanity
// ---------------------------------------------------------------------------
func TestRedeemAction_DistinctValues(t *testing.T) {
t.Parallel()
// Ensure the three actions have distinct values (iota correctness)
assert.NotEqual(t, redeemActionCreate, redeemActionRedeem)
assert.NotEqual(t, redeemActionCreate, redeemActionSkipCompleted)
assert.NotEqual(t, redeemActionRedeem, redeemActionSkipCompleted)
}
// ---------------------------------------------------------------------------
// RedeemCode.IsUsed / CanUse interaction with resolveRedeemAction
// ---------------------------------------------------------------------------
func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
t.Parallel()
usedCode := &RedeemCode{Status: StatusUsed}
unusedCode := &RedeemCode{Status: StatusUnused}
// Verify our decision function is consistent with the domain model methods
assert.True(t, usedCode.IsUsed())
assert.False(t, usedCode.CanUse())
assert.Equal(t, redeemActionSkipCompleted, resolveRedeemAction(usedCode, nil))
assert.False(t, unusedCode.IsUsed())
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Order Creation ---
func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest) (*CreateOrderResponse, error) {
if req.OrderType == "" {
req.OrderType = payment.OrderTypeBalance
}
cfg, err := s.configService.GetPaymentConfig(ctx)
if err != nil {
return nil, fmt.Errorf("get payment config: %w", err)
}
if !cfg.Enabled {
return nil, infraerrors.Forbidden("PAYMENT_DISABLED", "payment system is disabled")
}
plan, err := s.validateOrderInput(ctx, req, cfg)
if err != nil {
return nil, err
}
if err := s.checkCancelRateLimit(ctx, req.UserID, cfg); err != nil {
return nil, err
}
user, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user.Status != payment.EntityStatusActive {
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
}
amount := req.Amount
if plan != nil {
amount = plan.Price
}
feeRate := s.getFeeRate(req.PaymentType)
payAmountStr := payment.CalculatePayAmount(amount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, amount, feeRate, payAmount)
if err != nil {
return nil, err
}
resp, err := s.invokeProvider(ctx, order, req, cfg, payAmountStr, payAmount, plan)
if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed).
Save(ctx)
return nil, err
}
return resp, nil
}
func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig) (*dbent.SubscriptionPlan, error) {
if req.OrderType == payment.OrderTypeBalance && cfg.BalanceDisabled {
return nil, infraerrors.Forbidden("BALANCE_PAYMENT_DISABLED", "balance recharge has been disabled")
}
if req.OrderType == payment.OrderTypeSubscription {
return s.validateSubOrder(ctx, req)
}
if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number")
}
if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range").
WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)})
}
return nil, nil
}
func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRequest) (*dbent.SubscriptionPlan, error) {
if req.PlanID == 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "subscription order requires a plan")
}
plan, err := s.configService.GetPlan(ctx, req.PlanID)
if err != nil || !plan.ForSale {
return nil, infraerrors.NotFound("PLAN_NOT_AVAILABLE", "plan not found or not for sale")
}
group, err := s.groupRepo.GetByID(ctx, plan.GroupID)
if err != nil || group.Status != payment.EntityStatusActive {
return nil, infraerrors.NotFound("GROUP_NOT_FOUND", "subscription group is no longer available")
}
if !group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("GROUP_TYPE_MISMATCH", "group is not a subscription type")
}
return plan, nil
}
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, amount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := s.checkPendingLimit(ctx, tx, req.UserID, cfg.MaxPendingOrders); err != nil {
return nil, err
}
if err := s.checkDailyLimit(ctx, tx, req.UserID, amount, cfg.DailyLimit); err != nil {
return nil, err
}
tm := cfg.OrderTimeoutMin
if tm <= 0 {
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetNillableUserNotes(psNilIfEmpty(user.Notes)).
SetAmount(amount).
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
SetStatus(OrderStatusPending).
SetExpiresAt(exp).
SetClientIP(req.ClientIP).
SetSrcHost(req.SrcHost)
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
order, err := b.Save(ctx)
if err != nil {
return nil, fmt.Errorf("create order: %w", err)
}
code := fmt.Sprintf("PAY-%d-%d", order.ID, time.Now().UnixNano()%100000)
order, err = tx.PaymentOrder.UpdateOneID(order.ID).SetRechargeCode(code).Save(ctx)
if err != nil {
return nil, fmt.Errorf("set recharge code: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit order transaction: %w", err)
}
return order, nil
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
}
c, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusEQ(OrderStatusPending)).Count(ctx)
if err != nil {
return fmt.Errorf("count pending orders: %w", err)
}
if c >= max {
return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
}
return nil
}
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
return nil
}
windowStart := cancelRateLimitWindowStart(cfg)
operator := fmt.Sprintf("user:%d", userID)
count, err := s.entClient.PaymentAuditLog.Query().
Where(
paymentauditlog.ActionEQ("ORDER_CANCELLED"),
paymentauditlog.OperatorEQ(operator),
paymentauditlog.CreatedAtGTE(windowStart),
).Count(ctx)
if err != nil {
slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
return nil // fail open
}
if count >= cfg.CancelRateLimitMax {
return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
WithMetadata(map[string]string{
"max": strconv.Itoa(cfg.CancelRateLimitMax),
"window": strconv.Itoa(cfg.CancelRateLimitWindow),
"unit": cfg.CancelRateLimitUnit,
})
}
return nil
}
func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
now := time.Now()
w := cfg.CancelRateLimitWindow
if w <= 0 {
w = 1
}
unit := cfg.CancelRateLimitUnit
if unit == "" {
unit = "day"
}
if cfg.CancelRateLimitMode == "fixed" {
switch unit {
case "minute":
t := now.Truncate(time.Minute)
return t.Add(-time.Duration(w-1) * time.Minute)
case "day":
y, m, d := now.Date()
t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
return t.AddDate(0, 0, -(w - 1))
default: // hour
t := now.Truncate(time.Hour)
return t.Add(-time.Duration(w-1) * time.Hour)
}
}
// rolling window
switch unit {
case "minute":
return now.Add(-time.Duration(w) * time.Minute)
case "day":
return now.AddDate(0, 0, -w)
default: // hour
return now.Add(-time.Duration(w) * time.Hour)
}
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
}
ts := psStartOfDayUTC(time.Now())
orders, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusIn(OrderStatusPaid, OrderStatusRecharging, OrderStatusCompleted), paymentorder.PaidAtGTE(ts)).All(ctx)
if err != nil {
return fmt.Errorf("query daily usage: %w", err)
}
var used float64
for _, o := range orders {
used += o.Amount
}
if used+amount > limit {
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
}
return nil
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
s.EnsureProviders(ctx)
providerKey := s.registry.GetProviderKey(req.PaymentType)
if providerKey == "" {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, fmt.Errorf("select provider instance: %w", err)
}
if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
}
prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
}
subject := s.buildPaymentSubject(plan, payAmountStr, cfg)
outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
s.writeAuditLog(ctx, order.ID, "ORDER_CREATED", fmt.Sprintf("user:%d", req.UserID), map[string]any{"amount": req.Amount, "paymentType": req.PaymentType, "orderType": req.OrderType})
return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, payAmountStr string, cfg *PaymentConfig) string {
if plan != nil {
if plan.ProductName != "" {
return plan.ProductName
}
return "Sub2API Subscription " + plan.Name
}
pf := strings.TrimSpace(cfg.ProductNamePrefix)
sf := strings.TrimSpace(cfg.ProductNameSuffix)
if pf != "" || sf != "" {
return strings.TrimSpace(pf + " " + payAmountStr + " " + sf)
}
return "Sub2API " + payAmountStr + " CNY"
}
// --- Order Queries ---
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
return o, nil
}
func (s *PaymentService) GetOrderByID(ctx context.Context, orderID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return o, nil
}
func (s *PaymentService) GetUserOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
q := s.entClient.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID))
if p.Status != "" {
q = q.Where(paymentorder.StatusEQ(p.Status))
}
if p.OrderType != "" {
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
}
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count user orders: %w", err)
}
ps, pg := applyPagination(p.PageSize, p.Page)
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
if err != nil {
return nil, 0, fmt.Errorf("query user orders: %w", err)
}
return orders, total, nil
}
// AdminListOrders returns a paginated list of orders. If userID > 0, filters by user.
func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
q := s.entClient.PaymentOrder.Query()
if userID > 0 {
q = q.Where(paymentorder.UserIDEQ(userID))
}
if p.Status != "" {
q = q.Where(paymentorder.StatusEQ(p.Status))
}
if p.OrderType != "" {
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
}
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
if p.Keyword != "" {
q = q.Where(paymentorder.Or(
paymentorder.OutTradeNoContainsFold(p.Keyword),
paymentorder.UserEmailContainsFold(p.Keyword),
paymentorder.UserNameContainsFold(p.Keyword),
))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count admin orders: %w", err)
}
ps, pg := applyPagination(p.PageSize, p.Page)
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
if err != nil {
return nil, 0, fmt.Errorf("query admin orders: %w", err)
}
return orders, total, nil
}
// --- Cancel & Expire ---
func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
}
func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
}
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
if o.PaymentTradeNo != "" || o.PaymentType != "" {
if s.checkPaid(ctx, o) == "already_paid" {
return "already_paid", nil
}
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
if err != nil {
return "", fmt.Errorf("update order status: %w", err)
}
if c > 0 {
auditAction := "ORDER_CANCELLED"
if fs == OrderStatusExpired {
auditAction = "ORDER_EXPIRED"
}
s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
}
return "cancelled", nil
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
}
resp, err := prov.QueryOrder(ctx, tradeNo)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return "already_paid"
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
}
return ""
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
// Only verify orders that are still pending or recently expired
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
// Reload order to get updated status
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
if err != nil {
return 0, fmt.Errorf("query expired: %w", err)
}
n := 0
for _, o := range orders {
// Check upstream payment status before expiring — the user may have
// paid just before timeout and the webhook hasn't arrived yet.
outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
if outcome == "already_paid" {
slog.Info("order was paid during expiry", "orderID", o.ID)
continue
}
if outcome != "" {
n++
}
}
return n, nil
}
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
}
}
}
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
package service
import (
"context"
"log/slog"
"sync"
"time"
)
const expiryCheckTimeout = 30 * time.Second
// PaymentOrderExpiryService periodically expires timed-out payment orders.
type PaymentOrderExpiryService struct {
paymentSvc *PaymentService
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewPaymentOrderExpiryService(paymentSvc *PaymentService, interval time.Duration) *PaymentOrderExpiryService {
return &PaymentOrderExpiryService{
paymentSvc: paymentSvc,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *PaymentOrderExpiryService) Start() {
if s == nil || s.paymentSvc == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *PaymentOrderExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *PaymentOrderExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
defer cancel()
expired, err := s.paymentSvc.ExpireTimedOutOrders(ctx)
if err != nil {
slog.Error("[PaymentOrderExpiry] failed to expire orders", "error", err)
return
}
if expired > 0 {
slog.Info("[PaymentOrderExpiry] expired timed-out orders", "count", expired)
}
}
package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Refund Flow ---
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
return err
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if u.Balance < o.Amount {
return infraerrors.BadRequest("BALANCE_NOT_ENOUGH", "refund amount exceeds balance")
}
nr := strings.TrimSpace(reason)
now := time.Now()
by := fmt.Sprintf("%d", uid)
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.UserIDEQ(uid), paymentorder.StatusEQ(OrderStatusCompleted), paymentorder.OrderTypeEQ(payment.OrderTypeBalance)).SetStatus(OrderStatusRefundRequested).SetRefundRequestedAt(now).SetRefundRequestReason(nr).SetRefundRequestedBy(by).SetRefundAmount(o.Amount).Save(ctx)
if err != nil {
return fmt.Errorf("update: %w", err)
}
if c == 0 {
return infraerrors.Conflict("CONFLICT", "order status changed")
}
s.writeAuditLog(ctx, oid, "REFUND_REQUESTED", fmt.Sprintf("user:%d", uid), map[string]any{"amount": o.Amount, "reason": nr})
return nil
}
func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != uid {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission")
}
if o.OrderType != payment.OrderTypeBalance {
return nil, infraerrors.BadRequest("INVALID_ORDER_TYPE", "only balance orders can request refund")
}
if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
return o, nil
}
func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float64, reason string, force, deduct bool) (*RefundPlan, *RefundResult, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
return nil, nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
ok := []string{OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed}
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
}
if amt <= 0 {
amt = o.Amount
}
if amt-o.Amount > amountToleranceCNY {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
}
// Full refund: use actual pay_amount for gateway (includes fees)
ga := amt
if math.Abs(amt-o.Amount) <= amountToleranceCNY {
ga = o.PayAmount
}
rr := strings.TrimSpace(reason)
if rr == "" && o.RefundRequestReason != nil {
rr = *o.RefundRequestReason
}
if rr == "" {
rr = fmt.Sprintf("refund order:%d", o.ID)
}
p := &RefundPlan{OrderID: oid, Order: o, RefundAmount: amt, GatewayAmount: ga, Reason: rr, Force: force, DeductBalance: deduct, DeductionType: payment.DeductionTypeNone}
if deduct {
if er := s.prepDeduct(ctx, o, p, force); er != nil {
return nil, er, nil
}
}
return p, nil, nil
}
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
if o.OrderType == payment.OrderTypeSubscription {
p.DeductionType = payment.DeductionTypeSubscription
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
p.SubDaysToDeduct = *o.SubscriptionDays
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
if err == nil && sub != nil {
p.SubscriptionID = sub.ID
} else if !force {
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
}
}
return nil
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
if err != nil {
if !force {
return &RefundResult{Success: false, Warning: "cannot fetch user balance, use force", RequireForce: true}
}
return nil
}
p.DeductionType = payment.DeductionTypeBalance
p.BalanceToDeduct = math.Min(p.RefundAmount, u.Balance)
return nil
}
func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(p.OrderID), paymentorder.StatusIn(OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed)).SetStatus(OrderStatusRefunding).Save(ctx)
if err != nil {
return nil, fmt.Errorf("lock: %w", err)
}
if c == 0 {
return nil, infraerrors.Conflict("CONFLICT", "order status changed")
}
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
// Skip balance deduction on retry if previous attempt already deducted
// but failed to roll back (REFUND_ROLLBACK_FAILED in audit log).
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("deduction: %w", err)
}
} else {
slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.BalanceToDeduct = 0
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil {
// If deducting would expire the subscription, revoke it entirely
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
}
} else {
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.SubDaysToDeduct = 0
}
}
if err := s.gwRefund(ctx, p); err != nil {
return s.handleGwFail(ctx, p, err)
}
return s.markRefundOk(ctx, p)
}
func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if p.Order.PaymentTradeNo == "" {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
return nil
}
// Use the exact provider instance that created this order, not a random one
// from the registry. Each instance has its own merchant credentials.
prov, err := s.getRefundProvider(ctx, p.Order)
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
Reason: p.Reason,
})
return err
}
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
return s.getOrderProvider(ctx, o)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
if s.RollbackRefund(ctx, p, gErr) {
s.restoreStatus(ctx, p)
s.writeAuditLog(ctx, p.OrderID, "REFUND_GATEWAY_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
return &RefundResult{Success: false, Warning: "gateway failed: " + psErrMsg(gErr) + ", rolled back"}, nil
}
now := time.Now()
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(OrderStatusRefundFailed).SetFailedAt(now).SetFailedReason(psErrMsg(gErr)).Save(ctx)
s.writeAuditLog(ctx, p.OrderID, "REFUND_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
return nil, infraerrors.InternalServer("REFUND_FAILED", psErrMsg(gErr))
}
func (s *PaymentService) markRefundOk(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
fs := OrderStatusRefunded
if p.RefundAmount < p.Order.Amount {
fs = OrderStatusPartiallyRefunded
}
now := time.Now()
_, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(fs).SetRefundAmount(p.RefundAmount).SetRefundReason(p.Reason).SetRefundAt(now).SetForceRefund(p.Force).Save(ctx)
if err != nil {
return nil, fmt.Errorf("mark refund: %w", err)
}
s.writeAuditLog(ctx, p.OrderID, "REFUND_SUCCESS", "admin", map[string]any{"refundAmount": p.RefundAmount, "reason": p.Reason, "balanceDeducted": p.BalanceToDeduct, "force": p.Force})
return &RefundResult{Success: true, BalanceDeducted: p.BalanceToDeduct, SubDaysDeducted: p.SubDaysToDeduct}, nil
}
func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr error) bool {
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
if err := s.userRepo.UpdateBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
slog.Error("[CRITICAL] rollback failed", "orderID", p.OrderID, "amount", p.BalanceToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "balanceDeducted": p.BalanceToDeduct})
return false
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
return false
}
}
return true
}
func (s *PaymentService) restoreStatus(ctx context.Context, p *RefundPlan) {
rs := OrderStatusCompleted
if p.Order.Status == OrderStatusRefundRequested {
rs = OrderStatusRefundRequested
}
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(rs).Save(ctx)
}
package service
import (
"context"
"fmt"
"log/slog"
"math/rand/v2"
"sync"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
)
// --- Order Status Constants ---
const (
OrderStatusPending = payment.OrderStatusPending
OrderStatusPaid = payment.OrderStatusPaid
OrderStatusRecharging = payment.OrderStatusRecharging
OrderStatusCompleted = payment.OrderStatusCompleted
OrderStatusExpired = payment.OrderStatusExpired
OrderStatusCancelled = payment.OrderStatusCancelled
OrderStatusFailed = payment.OrderStatusFailed
OrderStatusRefundRequested = payment.OrderStatusRefundRequested
OrderStatusRefunding = payment.OrderStatusRefunding
OrderStatusPartiallyRefunded = payment.OrderStatusPartiallyRefunded
OrderStatusRefunded = payment.OrderStatusRefunded
OrderStatusRefundFailed = payment.OrderStatusRefundFailed
)
const (
// defaultMaxPendingOrders and defaultOrderTimeoutMin are defined in
// payment_config_service.go alongside other payment configuration defaults.
paymentGraceMinutes = 5
defaultPageSize = 20
maxPageSize = 100
topUsersLimit = 10
amountToleranceCNY = 0.01
orderIDPrefix = "sub2_"
)
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
// Format: sub2_20250409aB3kX9mQ (prefix + date + 8-char random)
func generateOutTradeNo() string {
date := time.Now().Format("20060102")
rnd := generateRandomString(8)
return orderIDPrefix + date + rnd
}
func generateRandomString(n int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, n)
for i := range b {
b[i] = charset[rand.IntN(len(charset))]
}
return string(b)
}
type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
OrderType string
PlanID int64
}
type CreateOrderResponse struct {
OrderID int64 `json:"order_id"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Status string `json:"status"`
PaymentType string `json:"payment_type"`
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
}
type OrderListParams struct {
Page int
PageSize int
Status string
OrderType string
PaymentType string
Keyword string
}
type RefundPlan struct {
OrderID int64
Order *dbent.PaymentOrder
RefundAmount float64
GatewayAmount float64
Reason string
Force bool
DeductBalance bool
DeductionType string
BalanceToDeduct float64
SubDaysToDeduct int
SubscriptionID int64
}
type RefundResult struct {
Success bool `json:"success"`
Warning string `json:"warning,omitempty"`
RequireForce bool `json:"require_force,omitempty"`
BalanceDeducted float64 `json:"balance_deducted,omitempty"`
SubDaysDeducted int `json:"subscription_days_deducted,omitempty"`
}
type DashboardStats struct {
TodayAmount float64 `json:"today_amount"`
TotalAmount float64 `json:"total_amount"`
TodayCount int `json:"today_count"`
TotalCount int `json:"total_count"`
AvgAmount float64 `json:"avg_amount"`
PendingOrders int `json:"pending_orders"`
DailySeries []DailyStats `json:"daily_series"`
PaymentMethods []PaymentMethodStat `json:"payment_methods"`
TopUsers []TopUserStat `json:"top_users"`
}
type DailyStats struct {
Date string `json:"date"`
Amount float64 `json:"amount"`
Count int `json:"count"`
}
type PaymentMethodStat struct {
Type string `json:"type"`
Amount float64 `json:"amount"`
Count int `json:"count"`
}
type TopUserStat struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Amount float64 `json:"amount"`
}
// --- Service ---
type PaymentService struct {
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
}
// --- Provider Registry ---
// EnsureProviders lazily initializes the provider registry on first call.
func (s *PaymentService) EnsureProviders(ctx context.Context) {
s.providerMu.Lock()
defer s.providerMu.Unlock()
if !s.providersLoaded {
s.loadProviders(ctx)
s.providersLoaded = true
}
}
// RefreshProviders clears and re-registers all providers from the database.
func (s *PaymentService) RefreshProviders(ctx context.Context) {
s.providerMu.Lock()
defer s.providerMu.Unlock()
s.registry.Clear()
s.loadProviders(ctx)
s.providersLoaded = true
}
func (s *PaymentService) loadProviders(ctx context.Context) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
slog.Error("[PaymentService] failed to query provider instances", "error", err)
return
}
for _, inst := range instances {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
if err != nil {
slog.Warn("[PaymentService] failed to decrypt config for instance", "instanceID", inst.ID, "error", err)
continue
}
if inst.PaymentMode != "" {
cfg["paymentMode"] = inst.PaymentMode
}
instID := fmt.Sprintf("%d", inst.ID)
p, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
if err != nil {
slog.Warn("[PaymentService] failed to create provider for instance", "instanceID", inst.ID, "key", inst.ProviderKey, "error", err)
continue
}
s.registry.Register(p)
}
}
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It extracts out_trade_no from the raw body, looks up the order to find the
// original provider instance, and creates a provider with that instance's credentials.
// Falls back to the registry provider when the order cannot be found.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
p, pErr := s.getOrderProvider(ctx, order)
if pErr == nil {
return p, nil
}
slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
}
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
}
// --- Helpers ---
func psIsRefundStatus(s string) bool {
switch s {
case OrderStatusRefundRequested, OrderStatusRefunding, OrderStatusPartiallyRefunded, OrderStatusRefunded, OrderStatusRefundFailed:
return true
}
return false
}
func psErrMsg(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func psNilIfEmpty(s string) *string {
if s == "" {
return nil
}
return &s
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
return true
}
}
return false
}
func psComputeValidityDays(days int, unit string) int {
switch unit {
case "week":
return days * 7
case "month":
return days * 30
default:
return days
}
}
func (s *PaymentService) getFeeRate(_ string) float64 { return 0 }
func psStartOfDayUTC(t time.Time) time.Time {
y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func applyPagination(pageSize, page int) (size, pg int) {
size = pageSize
if size <= 0 {
size = defaultPageSize
}
if size > maxPageSize {
size = maxPageSize
}
pg = page
if pg < 1 {
pg = 1
}
return size, pg
}
package service
import (
"context"
"encoding/json"
"log/slog"
"math"
"sort"
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
)
// --- Dashboard & Analytics ---
func (s *PaymentService) GetDashboardStats(ctx context.Context, days int) (*DashboardStats, error) {
if days <= 0 {
days = 30
}
now := time.Now()
since := now.AddDate(0, 0, -days)
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
paidStatuses := []string{OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging}
orders, err := s.entClient.PaymentOrder.Query().
Where(
paymentorder.StatusIn(paidStatuses...),
paymentorder.PaidAtGTE(since),
).
All(ctx)
if err != nil {
return nil, err
}
st := &DashboardStats{}
computeBasicStats(st, orders, todayStart)
st.PendingOrders, err = s.entClient.PaymentOrder.Query().
Where(paymentorder.StatusEQ(OrderStatusPending)).
Count(ctx)
if err != nil {
return nil, err
}
st.DailySeries = buildDailySeries(orders, since, days)
st.PaymentMethods = buildMethodDistribution(orders)
st.TopUsers = buildTopUsers(orders)
return st, nil
}
func computeBasicStats(st *DashboardStats, orders []*dbent.PaymentOrder, todayStart time.Time) {
var totalAmount, todayAmount float64
var todayCount int
for _, o := range orders {
totalAmount += o.PayAmount
if o.PaidAt != nil && !o.PaidAt.Before(todayStart) {
todayAmount += o.PayAmount
todayCount++
}
}
st.TotalAmount = math.Round(totalAmount*100) / 100
st.TodayAmount = math.Round(todayAmount*100) / 100
st.TotalCount = len(orders)
st.TodayCount = todayCount
if st.TotalCount > 0 {
st.AvgAmount = math.Round(totalAmount/float64(st.TotalCount)*100) / 100
}
}
func buildDailySeries(orders []*dbent.PaymentOrder, since time.Time, days int) []DailyStats {
dailyMap := make(map[string]*DailyStats)
for _, o := range orders {
if o.PaidAt == nil {
continue
}
date := o.PaidAt.Format("2006-01-02")
ds, ok := dailyMap[date]
if !ok {
ds = &DailyStats{Date: date}
dailyMap[date] = ds
}
ds.Amount += o.PayAmount
ds.Count++
}
series := make([]DailyStats, 0, days)
for i := 0; i < days; i++ {
date := since.AddDate(0, 0, i+1).Format("2006-01-02")
if ds, ok := dailyMap[date]; ok {
ds.Amount = math.Round(ds.Amount*100) / 100
series = append(series, *ds)
} else {
series = append(series, DailyStats{Date: date})
}
}
return series
}
func buildMethodDistribution(orders []*dbent.PaymentOrder) []PaymentMethodStat {
methodMap := make(map[string]*PaymentMethodStat)
for _, o := range orders {
ms, ok := methodMap[o.PaymentType]
if !ok {
ms = &PaymentMethodStat{Type: o.PaymentType}
methodMap[o.PaymentType] = ms
}
ms.Amount += o.PayAmount
ms.Count++
}
methods := make([]PaymentMethodStat, 0, len(methodMap))
for _, ms := range methodMap {
ms.Amount = math.Round(ms.Amount*100) / 100
methods = append(methods, *ms)
}
return methods
}
func buildTopUsers(orders []*dbent.PaymentOrder) []TopUserStat {
userMap := make(map[int64]*TopUserStat)
for _, o := range orders {
us, ok := userMap[o.UserID]
if !ok {
us = &TopUserStat{UserID: o.UserID, Email: o.UserEmail}
userMap[o.UserID] = us
}
us.Amount += o.PayAmount
}
userList := make([]*TopUserStat, 0, len(userMap))
for _, us := range userMap {
us.Amount = math.Round(us.Amount*100) / 100
userList = append(userList, us)
}
sort.Slice(userList, func(i, j int) bool {
return userList[i].Amount > userList[j].Amount
})
limit := topUsersLimit
if len(userList) < limit {
limit = len(userList)
}
result := make([]TopUserStat, 0, limit)
for i := 0; i < limit; i++ {
result = append(result, *userList[i])
}
return result
}
// --- Audit Logs ---
func (s *PaymentService) writeAuditLog(ctx context.Context, oid int64, action, op string, detail map[string]any) {
dj, _ := json.Marshal(detail)
_, err := s.entClient.PaymentAuditLog.Create().SetOrderID(strconv.FormatInt(oid, 10)).SetAction(action).SetDetail(string(dj)).SetOperator(op).Save(ctx)
if err != nil {
slog.Error("audit log failed", "orderID", oid, "action", action, "error", err)
}
}
func (s *PaymentService) GetOrderAuditLogs(ctx context.Context, oid int64) ([]*dbent.PaymentAuditLog, error) {
return s.entClient.PaymentAuditLog.Query().Where(paymentauditlog.OrderIDEQ(strconv.FormatInt(oid, 10))).Order(paymentauditlog.ByCreatedAt()).All(ctx)
}
...@@ -167,6 +167,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -167,6 +167,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
SettingKeyOIDCConnectEnabled, SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName, SettingKeyOIDCConnectProviderName,
SettingPaymentEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -227,6 +228,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -227,6 +228,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled, OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName, OIDCOAuthProviderName: oidcProviderName,
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
}, nil }, nil
} }
...@@ -276,6 +278,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -276,6 +278,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
...@@ -303,6 +306,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -303,6 +306,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName, OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version, Version: s.version,
}, nil }, nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment