Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{entClient: client},
}
if !svc.usesOfficialWxpayVisibleMethod(ctx) {
t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
}
}
......@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
......@@ -139,30 +140,86 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
queryRef := paymentOrderQueryReference(o, prov)
if queryRef == "" {
return ""
}
resp, err := prov.QueryOrder(ctx, tradeNo)
resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(o.ID)).
SetPaymentTradeNo(upstreamTradeNo).
Save(ctx); updateErr != nil {
slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
} else {
o.PaymentTradeNo = upstreamTradeNo
}
notificationTradeNo = upstreamTradeNo
}
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
_ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
if order == nil {
return ""
}
providerKey := ""
if prov != nil {
providerKey = strings.TrimSpace(prov.ProviderKey())
}
if providerKey == "" {
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
providerKey = strings.TrimSpace(snapshot.ProviderKey)
}
}
if providerKey == "" {
providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
}
if providerKey == "" {
providerKey = strings.TrimSpace(order.PaymentType)
}
switch payment.GetBasePaymentType(providerKey) {
case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
return strings.TrimSpace(order.OutTradeNo)
default:
if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
return tradeNo
}
return strings.TrimSpace(order.OutTradeNo)
}
}
func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
if upstreamTradeNo == "" {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
return false
}
return true
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
......@@ -190,8 +247,9 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
// VerifyOrderPublic returns the currently persisted public order state without
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
......@@ -199,15 +257,6 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == checkPaidResultAlreadyPaid {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
......@@ -236,22 +285,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
}
if !paymentOrderAllowsRegistryFallback(o) {
return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
}
providerKey := paymentOrderFallbackProviderKey(s.registry, o)
if providerKey == "" {
return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
if order == nil {
return false
}
if psOrderProviderSnapshot(order) != nil {
return false
}
if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
return false
}
if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
return false
}
return true
}
func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
if order == nil {
return ""
}
if registry != nil {
if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
return key
}
}
return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
}
func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
if inst == nil {
return nil, fmt.Errorf("payment provider instance is missing")
}
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
if err != nil {
return nil, fmt.Errorf("load provider instance config: %w", err)
}
if inst.PaymentMode != "" {
cfg["paymentMode"] = inst.PaymentMode
}
instID := strconv.FormatInt(int64(inst.ID), 10)
prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
if err != nil {
return nil, fmt.Errorf("create provider from instance: %w", err)
}
return prov, nil
}
func psStringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type paymentOrderLifecycleQueryProvider struct {
lastQueryTradeNo string
resp *payment.QueryOrderResponse
}
type paymentOrderLifecycleRedeemRepo struct {
codesByCode map[string]*RedeemCode
useCalls []struct {
id int64
userID int64
}
}
func (p *paymentOrderLifecycleQueryProvider) Name() string {
return "payment-order-lifecycle-query-provider"
}
func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
p.lastQueryTradeNo = tradeNo
return p.resp, nil
}
func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
for _, code := range r.codesByCode {
if code.ID != id {
continue
}
cloned := *code
return &cloned, nil
}
return nil, ErrRedeemCodeNotFound
}
func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
redeemCode, ok := r.codesByCode[code]
if !ok {
return nil, ErrRedeemCodeNotFound
}
cloned := *redeemCode
return &cloned, nil
}
func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
for code, redeemCode := range r.codesByCode {
if redeemCode.ID != id {
continue
}
now := time.Now().UTC()
redeemCode.Status = StatusUsed
redeemCode.UsedBy = &userID
redeemCode.UsedAt = &now
r.codesByCode[code] = redeemCode
r.useCalls = append(r.useCalls, struct {
id int64
userID int64
}{id: id, userID: userID})
return nil
}
return ErrRedeemCodeNotFound
}
func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected call")
}
func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
SetOutTradeNo("sub2_checkpaid_trade_no_missing").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-123",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusCompleted, got.Status)
require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusCompleted, reloaded.Status)
require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
require.Len(t, redeemRepo.useCalls, 1)
require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-existing-trade@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-existing-trade-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("upstream-trade-existing").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-existing",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
}
func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
t.Parallel()
require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
}))
instanceID := "12"
require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderInstanceID: &instanceID,
}))
require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"provider_instance_id": "12",
},
}))
}
func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
OutTradeNo: "sub2_out_trade_no",
PaymentTradeNo: "wx-transaction-id",
}
require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
key: payment.TypeWxpay,
}))
}
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
package service
import (
"context"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
type paymentOrderProviderSnapshot struct {
SchemaVersion int
ProviderInstanceID string
ProviderKey string
PaymentMode string
MerchantAppID string
MerchantID string
Currency string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
if order == nil || len(order.ProviderSnapshot) == 0 {
return nil
}
snapshot := &paymentOrderProviderSnapshot{
SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
}
if snapshot.SchemaVersion == 0 &&
snapshot.ProviderInstanceID == "" &&
snapshot.ProviderKey == "" &&
snapshot.PaymentMode == "" &&
snapshot.MerchantAppID == "" &&
snapshot.MerchantID == "" &&
snapshot.Currency == "" {
return nil
}
return snapshot
}
func psSnapshotStringValue(value any) string {
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
default:
return ""
}
}
func psSnapshotIntValue(value any) int {
switch typed := value.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float32:
return int(typed)
case float64:
return int(typed)
case string:
n, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return n
}
}
return 0
}
func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || order == nil || snapshot == nil {
return nil, nil
}
snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
if snapshotInstanceID == "" {
snapshotInstanceID = columnInstanceID
}
if snapshotInstanceID == "" {
return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
}
if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
}
instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
if err != nil {
return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
}
return nil, err
}
if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
}
return inst, nil
}
func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
if order == nil {
return strings.TrimSpace(instanceProviderKey)
}
orderProviderKey := psStringValue(order.ProviderKey)
if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
}
func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
if order == nil || len(metadata) == 0 {
return nil
}
snapshot := psOrderProviderSnapshot(order)
if snapshot == nil {
return nil
}
switch strings.TrimSpace(providerKey) {
case payment.TypeWxpay:
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["appid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing appid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["mchid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing mchid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("wxpay notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
}
}
if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
}
case payment.TypeAlipay:
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["app_id"])
if actual == "" {
return fmt.Errorf("alipay app_id missing")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
}
}
case payment.TypeEasyPay:
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["pid"])
if actual == "" {
return fmt.Errorf("easypay pid missing")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
}
}
}
return nil
}
func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
if prov == nil {
return nil
}
reporter, ok := prov.(payment.MerchantIdentityProvider)
if !ok {
return nil
}
return reporter.MerchantIdentityMetadata()
}
//go:build unit
package service
import (
"context"
"strconv"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
t.Parallel()
sel := &payment.InstanceSelection{
InstanceID: "12",
ProviderKey: payment.TypeWxpay,
SupportedTypes: "wxpay,wxpay_direct",
PaymentMode: "popup",
Config: map[string]string{
"privateKey": "secret",
"apiV3Key": "secret-v3",
"appId": "wx-app-id",
},
}
snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{
"schema_version": 2,
"provider_instance_id": "12",
"provider_key": payment.TypeWxpay,
"payment_mode": "popup",
"merchant_app_id": "wx-app-id",
"currency": "CNY",
}, snapshot)
require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name")
require.NotContains(t, snapshot, "merchant_id")
}
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("snapshot@example.com").
SetPasswordHash("hash").
SetUsername("snapshot-user").
Save(ctx)
require.NoError(t, err)
instance, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Primary Alipay").
SetConfig(`{"secretKey":"do-not-copy"}`).
SetSupportedTypes("alipay,alipay_direct").
SetPaymentMode("redirect").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{entClient: client}
order, err := svc.createOrderInTx(
ctx,
CreateOrderRequest{
UserID: user.ID,
PaymentType: payment.TypeAlipay,
OrderType: payment.OrderTypeBalance,
ClientIP: "127.0.0.1",
SrcHost: "app.example.com",
},
&User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
},
nil,
&PaymentConfig{
MaxPendingOrders: 3,
OrderTimeoutMin: 30,
},
88,
88,
0,
88,
&payment.InstanceSelection{
InstanceID: strconv.FormatInt(instance.ID, 10),
ProviderKey: payment.TypeAlipay,
SupportedTypes: "alipay,alipay_direct",
PaymentMode: "redirect",
Config: map[string]string{
"secretKey": "do-not-copy",
},
},
)
require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
require.NotContains(t, order.ProviderSnapshot, "config")
require.NotContains(t, order.ProviderSnapshot, "secretKey")
require.NotContains(t, order.ProviderSnapshot, "supported_types")
require.NotContains(t, order.ProviderSnapshot, "instance_name")
}
func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "88",
ProviderKey: payment.TypeWxpay,
Config: map[string]string{
"appId": "wx-open-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-88",
},
PaymentMode: "jsapi",
}, CreateOrderRequest{OpenID: "openid-123"})
require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
require.Equal(t, "mch-88", snapshot["merchant_id"])
require.Equal(t, "CNY", snapshot["currency"])
}
func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "21",
ProviderKey: payment.TypeAlipay,
Config: map[string]string{
"appId": "alipay-app-21",
"privateKey": "secret",
},
PaymentMode: "redirect",
}, CreateOrderRequest{})
require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
require.NotContains(t, snapshot, "privateKey")
}
func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "66",
ProviderKey: payment.TypeEasyPay,
Config: map[string]string{
"pid": "easypay-merchant-66",
"pkey": "secret",
},
PaymentMode: "popup",
}, CreateOrderRequest{PaymentType: payment.TypeAlipay})
require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
require.NotContains(t, snapshot, "pkey")
}
func valueOrEmpty(v *string) string {
if v == nil {
return ""
}
return *v
}
package service
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) {
t.Parallel()
expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC)
resp := buildCreateOrderResponse(
&dbent.PaymentOrder{
ID: 42,
Amount: 12.34,
FeeRate: 0.03,
ExpiresAt: expiresAt,
OutTradeNo: "sub2_42",
},
CreateOrderRequest{PaymentType: payment.TypeWxpay},
12.71,
&payment.InstanceSelection{PaymentMode: "qrcode"},
&payment.CreatePaymentResponse{
TradeNo: "sub2_42",
QRCode: "weixin://wxpay/bizpayurl?pr=test",
},
payment.CreatePaymentResultOrderCreated,
)
if resp.ResultType != payment.CreatePaymentResultOrderCreated {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated)
}
if resp.OutTradeNo != "sub2_42" {
t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42")
}
if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" {
t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test")
}
if resp.JSAPI != nil || resp.JSAPIPayload != nil {
t.Fatal("order_created response should not include jsapi payload")
}
if !resp.ExpiresAt.Equal(expiresAt) {
t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt)
}
}
func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
t.Parallel()
jsapiPayload := &payment.WechatJSAPIPayload{
AppID: "wx123",
TimeStamp: "1712345678",
NonceStr: "nonce-123",
Package: "prepay_id=wx123",
SignType: "RSA",
PaySign: "signed-payload",
}
resp := buildCreateOrderResponse(
&dbent.PaymentOrder{
ID: 88,
Amount: 66.88,
FeeRate: 0.01,
ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC),
OutTradeNo: "sub2_88",
},
CreateOrderRequest{PaymentType: payment.TypeWxpay},
67.55,
&payment.InstanceSelection{PaymentMode: "popup"},
&payment.CreatePaymentResponse{
TradeNo: "sub2_88",
ResultType: payment.CreatePaymentResultJSAPIReady,
JSAPI: jsapiPayload,
},
payment.CreatePaymentResultJSAPIReady,
)
if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
}
if resp.JSAPI == nil || resp.JSAPIPayload == nil {
t.Fatal("expected jsapi payload aliases to be populated")
}
if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload {
t.Fatal("expected jsapi aliases to preserve the original pointer")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
})
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil {
t.Fatal("expected oauth_required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil {
t.Fatal("expected oauth payload, got nil")
}
if resp.OAuth.AppID != "wx123456" {
t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456")
}
if resp.OAuth.Scope != "snsapi_base" {
t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base")
}
if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" {
t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback")
}
if resp.OAuth.AuthorizeURL != "/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" {
t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) {
t.Parallel()
svc := newWeChatPaymentOAuthTestService(nil)
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
})
resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03, &payment.InstanceSelection{
ProviderKey: payment.TypeEasyPay,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
}
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
},
}
}
......@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
......@@ -19,18 +20,133 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
// For legacy orders without provider_instance_id, it resolves only when the
// historical instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
}
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
}
instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
// getRefundOrderProviderInstance resolves the provider instance for refund paths.
// Refunds must be pinned to an explicit historical binding, so legacy
// "best-effort" provider guessing is intentionally not allowed here.
func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
}
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return nil, nil
}
instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
}
return nil, err
}
return inst, nil
}
func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
if providerKey != "" {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.ProviderKeyEQ(providerKey)).
All(ctx)
if err != nil {
return nil, err
}
matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
if len(matched) == 1 {
return matched[0], nil
}
return nil, nil
}
if paymentType == "" {
return nil, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
All(ctx)
if err != nil {
return nil, err
}
matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
if len(matched) == 1 {
return matched[0], nil
}
return nil, nil
}
func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance {
if len(instances) == 0 {
return nil
}
if strings.TrimSpace(orderPaymentType) == "" {
return instances
}
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
if psLegacyOrderMatchesInstance(orderPaymentType, inst) {
matched = append(matched, inst)
}
}
return matched
}
func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
if inst == nil {
return false
}
baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType))
instanceProviderKey := strings.TrimSpace(inst.ProviderKey)
if baseType == "" {
return false
}
if baseType == payment.TypeStripe {
return instanceProviderKey == payment.TypeStripe
}
if instanceProviderKey == payment.TypeStripe {
return false
}
if instanceProviderKey == baseType {
return true
}
return payment.InstanceSupportsType(inst.SupportedTypes, baseType)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
......@@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
......@@ -92,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
......@@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
"detail": err.Error(),
})
return err
}
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
......@@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
return s.getOrderProvider(ctx, o)
inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil {
return nil, err
}
if inst == nil {
return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
}
return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
......
//go:build unit
package service
import (
"context"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-legacy@example.com").
SetPasswordHash("hash").
SetUsername("refund-legacy-user").
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-instance").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetAllowUserRefund(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("REFUND-LEGACY-ORDER").
SetOutTradeNo("sub2_refund_legacy_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-legacy-refund").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
}
_, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
require.Error(t, err)
require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
}
func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-legacy-admin@example.com").
SetPasswordHash("hash").
SetUsername("refund-legacy-admin-user").
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-admin-instance").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetAllowUserRefund(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(188).
SetPayAmount(188).
SetFeeRate(0).
SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
SetOutTradeNo("sub2_refund_legacy_admin_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-legacy-admin-refund").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
}
plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
require.Nil(t, plan)
require.Nil(t, result)
require.Error(t, err)
require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
}
func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-snapshot-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("refund-snapshot-mismatch-user").
Save(ctx)
require.NoError(t, err)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-mismatch-instance").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{
"appId": "runtime-alipay-app",
"privateKey": "runtime-private-key",
})).
SetSupportedTypes("alipay").
SetEnabled(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
instID := strconv.FormatInt(inst.ID, 10)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-refund-snapshot-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instID).
SetProviderKey(payment.TypeAlipay).
SetProviderSnapshot(map[string]any{
"schema_version": 2,
"provider_instance_id": instID,
"provider_key": payment.TypeAlipay,
"merchant_app_id": "expected-alipay-app",
}).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
err = svc.gwRefund(ctx, &RefundPlan{
OrderID: order.ID,
Order: order,
RefundAmount: order.Amount,
GatewayAmount: order.Amount,
Reason: "snapshot mismatch",
})
require.ErrorContains(t, err, "alipay app_id mismatch")
}
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
)
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token))
if err != nil {
return nil, err
}
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil {
return nil, fmt.Errorf("get order by resume token: %w", err)
}
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
}
snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
if snapshot != nil {
if snapshot.ProviderInstanceID != "" {
orderProviderInstanceID = snapshot.ProviderInstanceID
}
if snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
}
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order)
if result == checkPaidResultAlreadyPaid {
order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
if err != nil {
return nil, fmt.Errorf("reload order by resume token: %w", err)
}
}
}
return order, nil
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
}
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
type paymentResumeLookupProvider struct {
queryCount int
}
func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" }
func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay }
func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
p.queryCount++
return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil
}
func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume@example.com").
SetPasswordHash("hash").
SetUsername("resume-user").
Save(ctx)
require.NoError(t, err)
instanceID := "12"
providerKey := payment.TypeEasyPay
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-ORDER").
SetOutTradeNo("sub2_resume_lookup").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-1").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instanceID).
SetProviderKey(providerKey).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: instanceID,
ProviderKey: providerKey,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
}
func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("resume-mismatch-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-MISMATCH").
SetOutTradeNo("sub2_resume_lookup_mismatch").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-2").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID("12").
SetProviderKey(payment.TypeEasyPay).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: "99",
ProviderKey: payment.TypeEasyPay,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
}
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-snapshot-authority@example.com").
SetPasswordHash("hash").
SetUsername("resume-snapshot-authority-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
SetOutTradeNo("sub2_resume_snapshot_authority").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-snapshot-authority").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID("legacy-column-instance").
SetProviderKey(payment.TypeAlipay).
SetProviderSnapshot(map[string]any{
"schema_version": 2,
"provider_instance_id": "snapshot-instance",
"provider_key": payment.TypeEasyPay,
}).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: "snapshot-instance",
ProviderKey: payment.TypeEasyPay,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
}
func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-refresh@example.com").
SetPasswordHash("hash").
SetUsername("resume-refresh-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-PENDING").
SetOutTradeNo("sub2_resume_lookup_pending").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-pending").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
registry := payment.NewRegistry()
provider := &paymentResumeLookupProvider{}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
resumeService: resumeSvc,
providersLoaded: true,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
require.Equal(t, 1, provider.queryCount)
}
func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("public-verify@example.com").
SetPasswordHash("hash").
SetUsername("public-verify-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("PUBLIC-VERIFY").
SetOutTradeNo("sub2_public_verify_pending").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-verify").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
registry := payment.NewRegistry()
provider := &paymentResumeLookupProvider{}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
providersLoaded: true,
}
got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount)
}
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const paymentResultReturnPath = "/payment/result"
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay = "official_alipay"
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
wechatPaymentResumeTokenType = "wechat_payment_resume"
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
paymentResumeTokenTTL = 24 * time.Hour
wechatPaymentResumeTokenTTL = 15 * time.Minute
)
type ResumeTokenClaims struct {
OrderID int64 `json:"oid"`
UserID int64 `json:"uid,omitempty"`
ProviderInstanceID string `json:"pi,omitempty"`
ProviderKey string `json:"pk,omitempty"`
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp,omitempty"`
}
type WeChatPaymentResumeClaims struct {
TokenType string `json:"tk,omitempty"`
OpenID string `json:"openid"`
PaymentType string `json:"pt,omitempty"`
Amount string `json:"amt,omitempty"`
OrderType string `json:"ot,omitempty"`
PlanID int64 `json:"pid,omitempty"`
RedirectTo string `json:"rd,omitempty"`
Scope string `json:"scp,omitempty"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp,omitempty"`
}
type PaymentResumeService struct {
signingKey []byte
}
type visibleMethodLoadBalancer struct {
inner payment.LoadBalancer
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func (s *PaymentResumeService) isSigningConfigured() bool {
return s != nil && len(s.signingKey) > 0
}
func (s *PaymentResumeService) ensureSigningKey() error {
if s.isSigningConfigured() {
return nil
}
return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
func NormalizeVisibleMethods(methods []string) []string {
if len(methods) == 0 {
return nil
}
seen := make(map[string]struct{}, len(methods))
out := make([]string, 0, len(methods))
for _, method := range methods {
normalized := NormalizeVisibleMethod(method)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
return out
}
func NormalizePaymentSource(source string) string {
switch strings.TrimSpace(strings.ToLower(source)) {
case "", PaymentSourceHostedRedirect:
return PaymentSourceHostedRedirect
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
return PaymentSourceWechatInAppResume
default:
return strings.TrimSpace(strings.ToLower(source))
}
}
func NormalizeVisibleMethodSource(method, source string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
return VisibleMethodSourceOfficialAlipay
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayAlipay
}
case payment.TypeWxpay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
return VisibleMethodSourceOfficialWechat
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayWechat
}
}
return ""
}
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
switch NormalizeVisibleMethodSource(method, source) {
case VisibleMethodSourceOfficialAlipay:
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceEasyPayAlipay:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceOfficialWechat:
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
case VisibleMethodSourceEasyPayWechat:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
default:
return "", false
}
}
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
if inner == nil || configService == nil || configService.entClient == nil {
return inner
}
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
}
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
return lb.inner.GetInstanceConfig(ctx, instanceID)
}
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
visibleMethod := NormalizeVisibleMethod(paymentType)
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
}
inst, err := lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod)
if err != nil {
return nil, err
}
if inst == nil {
return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod)
}
return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount)
}
func visibleMethodEnabledSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipayEnabled
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpayEnabled
default:
return ""
}
}
func visibleMethodSourceSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipaySource
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpaySource
default:
return ""
}
}
func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
}
parsed, err := url.Parse(raw)
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
}
parsed.Fragment = ""
if parsed.Path == "" {
parsed.Path = "/"
}
if parsed.Path != paymentResultReturnPath {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
}
if !sameOriginHost(parsed.Host, srcHost) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
}
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
canonical := strings.TrimSpace(base)
if canonical == "" {
return "", nil
}
parsed, err := url.Parse(canonical)
if err != nil {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
}
if !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
}
parsed.Fragment = ""
query := parsed.Query()
if orderID > 0 {
query.Set("order_id", strconv.FormatInt(orderID, 10))
}
if strings.TrimSpace(resumeToken) != "" {
query.Set("resume_token", strings.TrimSpace(resumeToken))
}
query.Set("status", "success")
parsed.RawQuery = query.Encode()
return parsed.String(), nil
}
func sameOriginHost(returnURLHost string, requestHost string) bool {
returnHost := strings.TrimSpace(returnURLHost)
reqHost := strings.TrimSpace(requestHost)
if returnHost == "" || reqHost == "" {
return false
}
if strings.EqualFold(returnHost, reqHost) {
return true
}
returnName, returnPort := splitHostPortDefault(returnHost)
reqName, reqPort := splitHostPortDefault(reqHost)
if returnName == "" || reqName == "" {
return false
}
return strings.EqualFold(returnName, reqName) && returnPort == reqPort
}
func splitHostPortDefault(raw string) (string, string) {
if host, port, err := net.SplitHostPort(raw); err == nil {
return host, port
}
return raw, ""
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
if claims.ExpiresAt == 0 {
claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
}
return s.createSignedToken(claims)
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims ResumeTokenClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
}
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
return nil, err
}
return &claims, nil
}
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return "", fmt.Errorf("wechat payment resume token requires openid")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
if claims.ExpiresAt == 0 {
claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
}
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
if claims.PaymentType == "" {
claims.PaymentType = payment.TypeWxpay
}
if claims.OrderType == "" {
claims.OrderType = payment.OrderTypeBalance
}
claims.TokenType = wechatPaymentResumeTokenType
return s.createSignedToken(claims)
}
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims WeChatPaymentResumeClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
}
if claims.TokenType != wechatPaymentResumeTokenType {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
}
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
return nil, err
}
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
if claims.PaymentType == "" {
claims.PaymentType = payment.TypeWxpay
}
if claims.OrderType == "" {
claims.OrderType = payment.OrderTypeBalance
}
return &claims, nil
}
func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return encodedPayload + "." + s.sign(encodedPayload), nil
}
func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
return json.Unmarshal(payload, dest)
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 {
return nil
}
if time.Now().Unix() > expiresAt {
return infraerrors.BadRequest(code, message)
}
return nil
}
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
//go:build unit
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/url"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeVisibleMethods(t *testing.T) {
t.Parallel()
got := NormalizeVisibleMethods([]string{
"alipay_direct",
"alipay",
" wxpay_direct ",
"wxpay",
"stripe",
})
want := []string{"alipay", "wxpay", "stripe"}
if len(got) != len(want) {
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestNormalizePaymentSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expect string
}{
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizePaymentSource(tt.input); got != tt.expect {
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://example.com/payment/result?b=2" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject external hosts")
}
}
func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
}
}
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
parsed, err := url.Parse(got)
if err != nil {
t.Fatalf("url.Parse returned error: %v", err)
}
if parsed.Fragment != "" {
t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment)
}
query := parsed.Query()
if query.Get("from") != "checkout" {
t.Fatalf("expected original query to be preserved, got %q", query.Get("from"))
}
if query.Get("order_id") != strconv.FormatInt(42, 10) {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("resume_token") != "resume-token" {
t.Fatalf("resume_token = %q", query.Get("resume_token"))
}
if query.Get("status") != "success" {
t.Fatalf("status = %q", query.Get("status"))
}
}
func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("", 42, "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
if got != "" {
t.Fatalf("buildPaymentReturnURL = %q, want empty string", got)
}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
if err == nil {
t.Fatal("CreateToken should reject missing signing key")
}
}
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject tokens when signing key is missing")
}
}
func TestParseTokenRejectsExpiredToken(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
_, err = svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject expired tokens")
}
}
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeSubscription,
PlanID: 7,
RedirectTo: "/purchase?from=wechat",
Scope: "snsapi_base",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
t.Fatalf("claims payment context mismatch: %+v", claims)
}
if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
t.Fatalf("claims redirect/scope mismatch: %+v", claims)
}
}
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
if err == nil {
t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
}
}
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
TokenType: wechatPaymentResumeTokenType,
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
}
}
func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
_, err = svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
input string
want string
}{
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
}
})
}
}
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
source string
want string
ok bool
}{
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
if got != tt.want || ok != tt.ok {
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
}
})
}
}
func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create alipay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: newPaymentConfigServiceTestClient(t),
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject when no enabled provider instance exists")
}
}
type captureLoadBalancer struct {
lastProviderKey string
lastPaymentType string
}
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
return map[string]string{}, nil
}
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
c.lastProviderKey = providerKey
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
t.Helper()
payload, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
mac := hmac.New(sha256.New, []byte("sub2api-payment-resume"))
_, _ = mac.Write([]byte(encodedPayload))
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return encodedPayload + "." + signature
}
......@@ -9,7 +9,6 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
......@@ -68,10 +67,14 @@ type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
OpenID string
ClientIP string
IsMobile bool
IsWeChatBrowser bool
SrcHost string
SrcURL string
ReturnURL string
PaymentSource string
OrderType string
PlanID int64
}
......@@ -82,12 +85,18 @@ type CreateOrderResponse struct {
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Status string `json:"status"`
ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"`
PaymentType string `json:"payment_type"`
OutTradeNo string `json:"out_trade_no,omitempty"`
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
......@@ -165,10 +174,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
return svc
}
// --- Provider Registry ---
......@@ -219,25 +231,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) {
}
}
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It extracts out_trade_no from the raw body, looks up the order to find the
// original provider instance, and creates a provider with that instance's credentials.
// Falls back to the registry provider when the order cannot be found.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
p, pErr := s.getOrderProvider(ctx, order)
if pErr == nil {
return p, nil
}
slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
}
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
}
// --- Helpers ---
func psIsRefundStatus(s string) bool {
......@@ -262,6 +255,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
......
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string {
methodSet := make(map[string]struct{}, 2)
addMethod := func(method string) {
method = NormalizeVisibleMethod(method)
switch method {
case payment.TypeAlipay, payment.TypeWxpay:
methodSet[method] = struct{}{}
}
}
switch strings.TrimSpace(providerKey) {
case payment.TypeAlipay:
if strings.TrimSpace(supportedTypes) == "" {
addMethod(payment.TypeAlipay)
break
}
for _, supportedType := range splitTypes(supportedTypes) {
if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay {
addMethod(payment.TypeAlipay)
break
}
}
case payment.TypeWxpay:
if strings.TrimSpace(supportedTypes) == "" {
addMethod(payment.TypeWxpay)
break
}
for _, supportedType := range splitTypes(supportedTypes) {
if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay {
addMethod(payment.TypeWxpay)
break
}
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(supportedTypes) {
addMethod(supportedType)
}
}
methods := make([]string, 0, len(methodSet))
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if _, ok := methodSet[method]; ok {
methods = append(methods, method)
}
}
return methods
}
func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool {
if inst == nil || !inst.Enabled {
return false
}
method = NormalizeVisibleMethod(method)
for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) {
if candidate == method {
return true
}
}
return false
}
func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
filtered = append(filtered, inst)
}
}
return filtered
}
func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error {
metadata := map[string]string{
"payment_method": NormalizeVisibleMethod(method),
}
if conflicting != nil {
metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID)
metadata["conflicting_provider_key"] = conflicting.ProviderKey
metadata["conflicting_provider_name"] = conflicting.Name
}
return infraerrors.Conflict(
"PAYMENT_PROVIDER_CONFLICT",
fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)),
).WithMetadata(metadata)
}
func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
ctx context.Context,
excludeID int64,
providerKey string,
supportedTypes string,
enabled bool,
) error {
if s == nil || s.entClient == nil || !enabled {
return nil
}
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
}
}
return nil
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
ctx context.Context,
method string,
) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil {
return nil, nil
}
method = NormalizeVisibleMethod(method)
if method != payment.TypeAlipay && method != payment.TypeWxpay {
return nil, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
Order(paymentproviderinstance.BySortOrder()).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query enabled payment providers: %w", err)
}
matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) {
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
return nil, buildPaymentProviderConflictError(method, matching[0])
}
}
package service
import (
"context"
"fmt"
"log/slog"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It resolves the original provider instance from the order whenever possible and
// only falls back to a registry provider for legacy/single-instance scenarios.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
if err != nil {
return nil, err
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers[0], nil
}
// GetWebhookProviders returns provider candidates that can verify the webhook.
// Official WeChat Pay may require multiple candidates because the callback body
// cannot be bound to a merchant before decryption.
func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
if psHasPinnedProviderInstance(order) {
prov, err := s.getPinnedOrderProvider(ctx, order)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
inst, err := s.getOrderProviderInstance(ctx, order)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
prov, err := s.createProviderFromInstance(ctx, inst)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst == nil {
return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
}
return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" || s == nil || s.entClient == nil {
return false
}
count, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Count(ctx)
if err != nil {
slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err)
return false
}
return count <= 1
}
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
providerKey = strings.TrimSpace(providerKey)
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query webhook provider instances: %w", err)
}
if len(instances) == 0 {
return nil, payment.ErrProviderNotFound
}
providers := make([]payment.Provider, 0, len(instances))
for _, inst := range instances {
prov, provErr := s.createProviderFromInstance(ctx, inst)
if provErr != nil {
slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
continue
}
providers = append(providers, prov)
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers, nil
}
//go:build unit
package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"strconv"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef"
type webhookProviderTestDouble struct {
key string
types []payment.PaymentType
}
func (p webhookProviderTestDouble) Name() string { return p.key }
func (p webhookProviderTestDouble) ProviderKey() string { return p.key }
func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types }
func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
panic("unexpected call")
}
func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string {
t.Helper()
data, err := json.Marshal(config)
require.NoError(t, err)
encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey))
require.NoError(t, err)
return encrypted
}
func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer {
return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey))
}
func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-" + suffix,
"mchId": "mch-" + suffix,
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-" + suffix,
"certSerial": "cert-serial-" + suffix,
})
}
func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-a").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})).
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
providerKey := payment.TypeStripe
order := &dbent.PaymentOrder{
PaymentType: payment.TypeStripe,
ProviderKey: &providerKey,
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, inst.ID, got.ID)
}
func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpayDirect,
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, inst.ID, got.ID)
}
func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("easypay-a").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.Nil(t, got)
}
func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-disabled-legacy").
SetConfig("{}").
SetSupportedTypes("stripe").
SetEnabled(false).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-enabled-current").
SetConfig("{}").
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
providerKey := payment.TypeStripe
order := &dbent.PaymentOrder{
PaymentType: payment.TypeStripe,
ProviderKey: &providerKey,
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.Nil(t, got)
}
func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-only").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
providerKey := payment.TypeWxpay
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAlipayDirect,
ProviderKey: &providerKey,
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.Nil(t, got)
}
func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-snapshot").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
ID: 42,
PaymentType: payment.TypeStripe,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": strconv.FormatInt(inst.ID, 10),
"provider_key": payment.TypeStripe,
},
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, inst.ID, got.ID)
}
func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-legacy-fallback").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
ID: 43,
PaymentType: payment.TypeStripe,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": "999999",
"provider_key": payment.TypeStripe,
},
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.Nil(t, got)
require.Error(t, err)
require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
}
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a")
wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b")
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-b").
SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
registry: payment.NewRegistry(),
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
require.NoError(t, err)
require.Len(t, providers, 2)
}
func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-a").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-b").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
registry: payment.NewRegistry(),
providersLoaded: true,
}
_, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
require.Error(t, err)
require.Contains(t, err.Error(), "ambiguous")
}
func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-a").
SetConfig("{}").
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
registry := payment.NewRegistry()
registry.Register(webhookProviderTestDouble{
key: payment.TypeStripe,
types: []payment.PaymentType{payment.TypeStripe},
})
svc := &PaymentService{
entClient: client,
registry: registry,
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
require.NoError(t, err)
require.Len(t, providers, 1)
prov := providers[0]
require.Equal(t, payment.TypeStripe, prov.ProviderKey())
}
func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("webhook@example.com").
SetPasswordHash("hash").
SetUsername("webhook").
Save(ctx)
require.NoError(t, err)
pinnedInstanceID := "999"
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("TEST-RECHARGE").
SetOutTradeNo("sub2_test_pinned_order").
SetPaymentType(payment.TypeWxpay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(pinnedInstanceID).
Save(ctx)
require.NoError(t, err)
registry := payment.NewRegistry()
registry.Register(webhookProviderTestDouble{
key: payment.TypeWxpay,
types: []payment.PaymentType{payment.TypeWxpay},
})
svc := &PaymentService{
entClient: client,
registry: registry,
providersLoaded: true,
}
_, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}
func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("snapshot-webhook@example.com").
SetPasswordHash("hash").
SetUsername("snapshot-webhook").
Save(ctx)
require.NoError(t, err)
wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a")
wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b")
instA, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-snapshot-a").
SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-snapshot-b").
SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(66).
SetPayAmount(66).
SetFeeRate(0).
SetRechargeCode("SNAPSHOT-WEBHOOK").
SetOutTradeNo("sub2_test_snapshot_webhook_order").
SetPaymentType(payment.TypeWxpay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderSnapshot(map[string]any{
"schema_version": 1,
"provider_instance_id": strconv.FormatInt(instA.ID, 10),
"provider_key": payment.TypeWxpay,
"payment_mode": "native",
}).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
registry: payment.NewRegistry(),
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
require.NoError(t, err)
require.Len(t, providers, 1)
require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment