Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{entClient: client},
}
if !svc.usesOfficialWxpayVisibleMethod(ctx) {
t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
}
}
......@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
......@@ -139,30 +140,86 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
queryRef := paymentOrderQueryReference(o, prov)
if queryRef == "" {
return ""
}
resp, err := prov.QueryOrder(ctx, tradeNo)
resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(o.ID)).
SetPaymentTradeNo(upstreamTradeNo).
Save(ctx); updateErr != nil {
slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
} else {
o.PaymentTradeNo = upstreamTradeNo
}
notificationTradeNo = upstreamTradeNo
}
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
_ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
if order == nil {
return ""
}
providerKey := ""
if prov != nil {
providerKey = strings.TrimSpace(prov.ProviderKey())
}
if providerKey == "" {
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
providerKey = strings.TrimSpace(snapshot.ProviderKey)
}
}
if providerKey == "" {
providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
}
if providerKey == "" {
providerKey = strings.TrimSpace(order.PaymentType)
}
switch payment.GetBasePaymentType(providerKey) {
case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
return strings.TrimSpace(order.OutTradeNo)
default:
if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
return tradeNo
}
return strings.TrimSpace(order.OutTradeNo)
}
}
func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
if upstreamTradeNo == "" {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
return false
}
return true
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
......@@ -190,8 +247,9 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
// VerifyOrderPublic returns the currently persisted public order state without
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
......@@ -199,15 +257,6 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == checkPaidResultAlreadyPaid {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
......@@ -236,22 +285,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
}
}
}
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
}
if !paymentOrderAllowsRegistryFallback(o) {
return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
}
providerKey := paymentOrderFallbackProviderKey(s.registry, o)
if providerKey == "" {
return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
if order == nil {
return false
}
if psOrderProviderSnapshot(order) != nil {
return false
}
if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
return false
}
if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
return false
}
return true
}
func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
if order == nil {
return ""
}
if registry != nil {
if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
return key
}
}
return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
}
func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
if inst == nil {
return nil, fmt.Errorf("payment provider instance is missing")
}
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
if err != nil {
return nil, fmt.Errorf("load provider instance config: %w", err)
}
if inst.PaymentMode != "" {
cfg["paymentMode"] = inst.PaymentMode
}
instID := strconv.FormatInt(int64(inst.ID), 10)
prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
if err != nil {
return nil, fmt.Errorf("create provider from instance: %w", err)
}
return prov, nil
}
func psStringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type paymentOrderLifecycleQueryProvider struct {
lastQueryTradeNo string
resp *payment.QueryOrderResponse
}
type paymentOrderLifecycleRedeemRepo struct {
codesByCode map[string]*RedeemCode
useCalls []struct {
id int64
userID int64
}
}
func (p *paymentOrderLifecycleQueryProvider) Name() string {
return "payment-order-lifecycle-query-provider"
}
func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
p.lastQueryTradeNo = tradeNo
return p.resp, nil
}
func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
for _, code := range r.codesByCode {
if code.ID != id {
continue
}
cloned := *code
return &cloned, nil
}
return nil, ErrRedeemCodeNotFound
}
func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
redeemCode, ok := r.codesByCode[code]
if !ok {
return nil, ErrRedeemCodeNotFound
}
cloned := *redeemCode
return &cloned, nil
}
func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
for code, redeemCode := range r.codesByCode {
if redeemCode.ID != id {
continue
}
now := time.Now().UTC()
redeemCode.Status = StatusUsed
redeemCode.UsedBy = &userID
redeemCode.UsedAt = &now
r.codesByCode[code] = redeemCode
r.useCalls = append(r.useCalls, struct {
id int64
userID int64
}{id: id, userID: userID})
return nil
}
return ErrRedeemCodeNotFound
}
func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected call")
}
func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
SetOutTradeNo("sub2_checkpaid_trade_no_missing").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-123",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusCompleted, got.Status)
require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusCompleted, reloaded.Status)
require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
require.Len(t, redeemRepo.useCalls, 1)
require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-existing-trade@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-existing-trade-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("upstream-trade-existing").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-existing",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
}
func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
t.Parallel()
require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
}))
instanceID := "12"
require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderInstanceID: &instanceID,
}))
require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"provider_instance_id": "12",
},
}))
}
func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
OutTradeNo: "sub2_out_trade_no",
PaymentTradeNo: "wx-transaction-id",
}
require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
key: payment.TypeWxpay,
}))
}
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
package service
import (
"context"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
type paymentOrderProviderSnapshot struct {
SchemaVersion int
ProviderInstanceID string
ProviderKey string
PaymentMode string
MerchantAppID string
MerchantID string
Currency string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
if order == nil || len(order.ProviderSnapshot) == 0 {
return nil
}
snapshot := &paymentOrderProviderSnapshot{
SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
}
if snapshot.SchemaVersion == 0 &&
snapshot.ProviderInstanceID == "" &&
snapshot.ProviderKey == "" &&
snapshot.PaymentMode == "" &&
snapshot.MerchantAppID == "" &&
snapshot.MerchantID == "" &&
snapshot.Currency == "" {
return nil
}
return snapshot
}
func psSnapshotStringValue(value any) string {
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
default:
return ""
}
}
func psSnapshotIntValue(value any) int {
switch typed := value.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float32:
return int(typed)
case float64:
return int(typed)
case string:
n, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return n
}
}
return 0
}
func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || order == nil || snapshot == nil {
return nil, nil
}
snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
if snapshotInstanceID == "" {
snapshotInstanceID = columnInstanceID
}
if snapshotInstanceID == "" {
return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
}
if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
}
instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
if err != nil {
return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
}
return nil, err
}
if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
}
return inst, nil
}
func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
if order == nil {
return strings.TrimSpace(instanceProviderKey)
}
orderProviderKey := psStringValue(order.ProviderKey)
if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
}
func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
if order == nil || len(metadata) == 0 {
return nil
}
snapshot := psOrderProviderSnapshot(order)
if snapshot == nil {
return nil
}
switch strings.TrimSpace(providerKey) {
case payment.TypeWxpay:
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["appid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing appid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["mchid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing mchid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("wxpay notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
}
}
if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
}
case payment.TypeAlipay:
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["app_id"])
if actual == "" {
return fmt.Errorf("alipay app_id missing")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
}
}
case payment.TypeEasyPay:
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["pid"])
if actual == "" {
return fmt.Errorf("easypay pid missing")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
}
}
}
return nil
}
func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
if prov == nil {
return nil
}
reporter, ok := prov.(payment.MerchantIdentityProvider)
if !ok {
return nil
}
return reporter.MerchantIdentityMetadata()
}
//go:build unit
package service
import (
"context"
"strconv"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
t.Parallel()
sel := &payment.InstanceSelection{
InstanceID: "12",
ProviderKey: payment.TypeWxpay,
SupportedTypes: "wxpay,wxpay_direct",
PaymentMode: "popup",
Config: map[string]string{
"privateKey": "secret",
"apiV3Key": "secret-v3",
"appId": "wx-app-id",
},
}
snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{
"schema_version": 2,
"provider_instance_id": "12",
"provider_key": payment.TypeWxpay,
"payment_mode": "popup",
"merchant_app_id": "wx-app-id",
"currency": "CNY",
}, snapshot)
require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name")
require.NotContains(t, snapshot, "merchant_id")
}
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("snapshot@example.com").
SetPasswordHash("hash").
SetUsername("snapshot-user").
Save(ctx)
require.NoError(t, err)
instance, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Primary Alipay").
SetConfig(`{"secretKey":"do-not-copy"}`).
SetSupportedTypes("alipay,alipay_direct").
SetPaymentMode("redirect").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{entClient: client}
order, err := svc.createOrderInTx(
ctx,
CreateOrderRequest{
UserID: user.ID,
PaymentType: payment.TypeAlipay,
OrderType: payment.OrderTypeBalance,
ClientIP: "127.0.0.1",
SrcHost: "app.example.com",
},
&User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
},
nil,
&PaymentConfig{
MaxPendingOrders: 3,
OrderTimeoutMin: 30,
},
88,
88,
0,
88,
&payment.InstanceSelection{
InstanceID: strconv.FormatInt(instance.ID, 10),
ProviderKey: payment.TypeAlipay,
SupportedTypes: "alipay,alipay_direct",
PaymentMode: "redirect",
Config: map[string]string{
"secretKey": "do-not-copy",
},
},
)
require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
require.NotContains(t, order.ProviderSnapshot, "config")
require.NotContains(t, order.ProviderSnapshot, "secretKey")
require.NotContains(t, order.ProviderSnapshot, "supported_types")
require.NotContains(t, order.ProviderSnapshot, "instance_name")
}
func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "88",
ProviderKey: payment.TypeWxpay,
Config: map[string]string{
"appId": "wx-open-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-88",
},
PaymentMode: "jsapi",
}, CreateOrderRequest{OpenID: "openid-123"})
require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
require.Equal(t, "mch-88", snapshot["merchant_id"])
require.Equal(t, "CNY", snapshot["currency"])
}
func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "21",
ProviderKey: payment.TypeAlipay,
Config: map[string]string{
"appId": "alipay-app-21",
"privateKey": "secret",
},
PaymentMode: "redirect",
}, CreateOrderRequest{})
require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
require.NotContains(t, snapshot, "privateKey")
}
func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "66",
ProviderKey: payment.TypeEasyPay,
Config: map[string]string{
"pid": "easypay-merchant-66",
"pkey": "secret",
},
PaymentMode: "popup",
}, CreateOrderRequest{PaymentType: payment.TypeAlipay})
require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
require.NotContains(t, snapshot, "pkey")
}
func valueOrEmpty(v *string) string {
if v == nil {
return ""
}
return *v
}
package service
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) {
t.Parallel()
expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC)
resp := buildCreateOrderResponse(
&dbent.PaymentOrder{
ID: 42,
Amount: 12.34,
FeeRate: 0.03,
ExpiresAt: expiresAt,
OutTradeNo: "sub2_42",
},
CreateOrderRequest{PaymentType: payment.TypeWxpay},
12.71,
&payment.InstanceSelection{PaymentMode: "qrcode"},
&payment.CreatePaymentResponse{
TradeNo: "sub2_42",
QRCode: "weixin://wxpay/bizpayurl?pr=test",
},
payment.CreatePaymentResultOrderCreated,
)
if resp.ResultType != payment.CreatePaymentResultOrderCreated {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated)
}
if resp.OutTradeNo != "sub2_42" {
t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42")
}
if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" {
t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test")
}
if resp.JSAPI != nil || resp.JSAPIPayload != nil {
t.Fatal("order_created response should not include jsapi payload")
}
if !resp.ExpiresAt.Equal(expiresAt) {
t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt)
}
}
func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
t.Parallel()
jsapiPayload := &payment.WechatJSAPIPayload{
AppID: "wx123",
TimeStamp: "1712345678",
NonceStr: "nonce-123",
Package: "prepay_id=wx123",
SignType: "RSA",
PaySign: "signed-payload",
}
resp := buildCreateOrderResponse(
&dbent.PaymentOrder{
ID: 88,
Amount: 66.88,
FeeRate: 0.01,
ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC),
OutTradeNo: "sub2_88",
},
CreateOrderRequest{PaymentType: payment.TypeWxpay},
67.55,
&payment.InstanceSelection{PaymentMode: "popup"},
&payment.CreatePaymentResponse{
TradeNo: "sub2_88",
ResultType: payment.CreatePaymentResultJSAPIReady,
JSAPI: jsapiPayload,
},
payment.CreatePaymentResultJSAPIReady,
)
if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
}
if resp.JSAPI == nil || resp.JSAPIPayload == nil {
t.Fatal("expected jsapi payload aliases to be populated")
}
if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload {
t.Fatal("expected jsapi aliases to preserve the original pointer")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
})
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil {
t.Fatal("expected oauth_required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil {
t.Fatal("expected oauth payload, got nil")
}
if resp.OAuth.AppID != "wx123456" {
t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456")
}
if resp.OAuth.Scope != "snsapi_base" {
t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base")
}
if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" {
t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback")
}
if resp.OAuth.AuthorizeURL != "/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" {
t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) {
t.Parallel()
svc := newWeChatPaymentOAuthTestService(nil)
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
})
resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03, &payment.InstanceSelection{
ProviderKey: payment.TypeEasyPay,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
}
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
},
}
}
......@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
......@@ -19,18 +20,133 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
// For legacy orders without provider_instance_id, it resolves only when the
// historical instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
}
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
}
instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
// getRefundOrderProviderInstance resolves the provider instance for refund paths.
// Refunds must be pinned to an explicit historical binding, so legacy
// "best-effort" provider guessing is intentionally not allowed here.
func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
}
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return nil, nil
}
instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
}
return nil, err
}
return inst, nil
}
func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
if providerKey != "" {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.ProviderKeyEQ(providerKey)).
All(ctx)
if err != nil {
return nil, err
}
matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
if len(matched) == 1 {
return matched[0], nil
}
return nil, nil
}
if paymentType == "" {
return nil, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
All(ctx)
if err != nil {
return nil, err
}
matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
if len(matched) == 1 {
return matched[0], nil
}
return nil, nil
}
func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance {
if len(instances) == 0 {
return nil
}
if strings.TrimSpace(orderPaymentType) == "" {
return instances
}
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
if psLegacyOrderMatchesInstance(orderPaymentType, inst) {
matched = append(matched, inst)
}
}
return matched
}
func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
if inst == nil {
return false
}
baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType))
instanceProviderKey := strings.TrimSpace(inst.ProviderKey)
if baseType == "" {
return false
}
if baseType == payment.TypeStripe {
return instanceProviderKey == payment.TypeStripe
}
if instanceProviderKey == payment.TypeStripe {
return false
}
if instanceProviderKey == baseType {
return true
}
return payment.InstanceSupportsType(inst.SupportedTypes, baseType)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
......@@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
......@@ -92,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
......@@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
"detail": err.Error(),
})
return err
}
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
......@@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
return s.getOrderProvider(ctx, o)
inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil {
return nil, err
}
if inst == nil {
return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
}
return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
......
//go:build unit
package service
import (
"context"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-legacy@example.com").
SetPasswordHash("hash").
SetUsername("refund-legacy-user").
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-instance").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetAllowUserRefund(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("REFUND-LEGACY-ORDER").
SetOutTradeNo("sub2_refund_legacy_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-legacy-refund").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
}
_, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
require.Error(t, err)
require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
}
func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-legacy-admin@example.com").
SetPasswordHash("hash").
SetUsername("refund-legacy-admin-user").
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-admin-instance").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetAllowUserRefund(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(188).
SetPayAmount(188).
SetFeeRate(0).
SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
SetOutTradeNo("sub2_refund_legacy_admin_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-legacy-admin-refund").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
}
plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
require.Nil(t, plan)
require.Nil(t, result)
require.Error(t, err)
require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
}
func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("refund-snapshot-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("refund-snapshot-mismatch-user").
Save(ctx)
require.NoError(t, err)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-refund-mismatch-instance").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{
"appId": "runtime-alipay-app",
"privateKey": "runtime-private-key",
})).
SetSupportedTypes("alipay").
SetEnabled(true).
SetRefundEnabled(true).
Save(ctx)
require.NoError(t, err)
instID := strconv.FormatInt(inst.ID, 10)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-refund-snapshot-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusCompleted).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instID).
SetProviderKey(payment.TypeAlipay).
SetProviderSnapshot(map[string]any{
"schema_version": 2,
"provider_instance_id": instID,
"provider_key": payment.TypeAlipay,
"merchant_app_id": "expected-alipay-app",
}).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
err = svc.gwRefund(ctx, &RefundPlan{
OrderID: order.ID,
Order: order,
RefundAmount: order.Amount,
GatewayAmount: order.Amount,
Reason: "snapshot mismatch",
})
require.ErrorContains(t, err, "alipay app_id mismatch")
}
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
)
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token))
if err != nil {
return nil, err
}
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil {
return nil, fmt.Errorf("get order by resume token: %w", err)
}
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
}
snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
if snapshot != nil {
if snapshot.ProviderInstanceID != "" {
orderProviderInstanceID = snapshot.ProviderInstanceID
}
if snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
}
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order)
if result == checkPaidResultAlreadyPaid {
order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
if err != nil {
return nil, fmt.Errorf("reload order by resume token: %w", err)
}
}
}
return order, nil
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
}
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
type paymentResumeLookupProvider struct {
queryCount int
}
func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" }
func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay }
func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
p.queryCount++
return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil
}
func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
panic("unexpected call")
}
func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume@example.com").
SetPasswordHash("hash").
SetUsername("resume-user").
Save(ctx)
require.NoError(t, err)
instanceID := "12"
providerKey := payment.TypeEasyPay
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-ORDER").
SetOutTradeNo("sub2_resume_lookup").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-1").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instanceID).
SetProviderKey(providerKey).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: instanceID,
ProviderKey: providerKey,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
}
func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("resume-mismatch-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-MISMATCH").
SetOutTradeNo("sub2_resume_lookup_mismatch").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-2").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID("12").
SetProviderKey(payment.TypeEasyPay).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: "99",
ProviderKey: payment.TypeEasyPay,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
}
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-snapshot-authority@example.com").
SetPasswordHash("hash").
SetUsername("resume-snapshot-authority-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
SetOutTradeNo("sub2_resume_snapshot_authority").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-snapshot-authority").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID("legacy-column-instance").
SetProviderKey(payment.TypeAlipay).
SetProviderSnapshot(map[string]any{
"schema_version": 2,
"provider_instance_id": "snapshot-instance",
"provider_key": payment.TypeEasyPay,
}).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: "snapshot-instance",
ProviderKey: payment.TypeEasyPay,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
}
func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-refresh@example.com").
SetPasswordHash("hash").
SetUsername("resume-refresh-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-PENDING").
SetOutTradeNo("sub2_resume_lookup_pending").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-pending").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
registry := payment.NewRegistry()
provider := &paymentResumeLookupProvider{}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
resumeService: resumeSvc,
providersLoaded: true,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
require.Equal(t, 1, provider.queryCount)
}
func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("public-verify@example.com").
SetPasswordHash("hash").
SetUsername("public-verify-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("PUBLIC-VERIFY").
SetOutTradeNo("sub2_public_verify_pending").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-verify").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
registry := payment.NewRegistry()
provider := &paymentResumeLookupProvider{}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
providersLoaded: true,
}
got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount)
}
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const paymentResultReturnPath = "/payment/result"
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay = "official_alipay"
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
wechatPaymentResumeTokenType = "wechat_payment_resume"
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
paymentResumeTokenTTL = 24 * time.Hour
wechatPaymentResumeTokenTTL = 15 * time.Minute
)
type ResumeTokenClaims struct {
OrderID int64 `json:"oid"`
UserID int64 `json:"uid,omitempty"`
ProviderInstanceID string `json:"pi,omitempty"`
ProviderKey string `json:"pk,omitempty"`
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp,omitempty"`
}
type WeChatPaymentResumeClaims struct {
TokenType string `json:"tk,omitempty"`
OpenID string `json:"openid"`
PaymentType string `json:"pt,omitempty"`
Amount string `json:"amt,omitempty"`
OrderType string `json:"ot,omitempty"`
PlanID int64 `json:"pid,omitempty"`
RedirectTo string `json:"rd,omitempty"`
Scope string `json:"scp,omitempty"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp,omitempty"`
}
type PaymentResumeService struct {
signingKey []byte
}
type visibleMethodLoadBalancer struct {
inner payment.LoadBalancer
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func (s *PaymentResumeService) isSigningConfigured() bool {
return s != nil && len(s.signingKey) > 0
}
func (s *PaymentResumeService) ensureSigningKey() error {
if s.isSigningConfigured() {
return nil
}
return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
func NormalizeVisibleMethods(methods []string) []string {
if len(methods) == 0 {
return nil
}
seen := make(map[string]struct{}, len(methods))
out := make([]string, 0, len(methods))
for _, method := range methods {
normalized := NormalizeVisibleMethod(method)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
return out
}
func NormalizePaymentSource(source string) string {
switch strings.TrimSpace(strings.ToLower(source)) {
case "", PaymentSourceHostedRedirect:
return PaymentSourceHostedRedirect
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
return PaymentSourceWechatInAppResume
default:
return strings.TrimSpace(strings.ToLower(source))
}
}
func NormalizeVisibleMethodSource(method, source string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
return VisibleMethodSourceOfficialAlipay
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayAlipay
}
case payment.TypeWxpay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
return VisibleMethodSourceOfficialWechat
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayWechat
}
}
return ""
}
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
switch NormalizeVisibleMethodSource(method, source) {
case VisibleMethodSourceOfficialAlipay:
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceEasyPayAlipay:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceOfficialWechat:
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
case VisibleMethodSourceEasyPayWechat:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
default:
return "", false
}
}
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
if inner == nil || configService == nil || configService.entClient == nil {
return inner
}
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
}
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
return lb.inner.GetInstanceConfig(ctx, instanceID)
}
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
visibleMethod := NormalizeVisibleMethod(paymentType)
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
}
inst, err := lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod)
if err != nil {
return nil, err
}
if inst == nil {
return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod)
}
return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount)
}
func visibleMethodEnabledSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipayEnabled
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpayEnabled
default:
return ""
}
}
func visibleMethodSourceSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipaySource
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpaySource
default:
return ""
}
}
func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
}
parsed, err := url.Parse(raw)
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
}
parsed.Fragment = ""
if parsed.Path == "" {
parsed.Path = "/"
}
if parsed.Path != paymentResultReturnPath {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
}
if !sameOriginHost(parsed.Host, srcHost) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
}
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
canonical := strings.TrimSpace(base)
if canonical == "" {
return "", nil
}
parsed, err := url.Parse(canonical)
if err != nil {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
}
if !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
}
parsed.Fragment = ""
query := parsed.Query()
if orderID > 0 {
query.Set("order_id", strconv.FormatInt(orderID, 10))
}
if strings.TrimSpace(resumeToken) != "" {
query.Set("resume_token", strings.TrimSpace(resumeToken))
}
query.Set("status", "success")
parsed.RawQuery = query.Encode()
return parsed.String(), nil
}
func sameOriginHost(returnURLHost string, requestHost string) bool {
returnHost := strings.TrimSpace(returnURLHost)
reqHost := strings.TrimSpace(requestHost)
if returnHost == "" || reqHost == "" {
return false
}
if strings.EqualFold(returnHost, reqHost) {
return true
}
returnName, returnPort := splitHostPortDefault(returnHost)
reqName, reqPort := splitHostPortDefault(reqHost)
if returnName == "" || reqName == "" {
return false
}
return strings.EqualFold(returnName, reqName) && returnPort == reqPort
}
func splitHostPortDefault(raw string) (string, string) {
if host, port, err := net.SplitHostPort(raw); err == nil {
return host, port
}
return raw, ""
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
if claims.ExpiresAt == 0 {
claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
}
return s.createSignedToken(claims)
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims ResumeTokenClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
}
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
return nil, err
}
return &claims, nil
}
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return "", fmt.Errorf("wechat payment resume token requires openid")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
if claims.ExpiresAt == 0 {
claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
}
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
if claims.PaymentType == "" {
claims.PaymentType = payment.TypeWxpay
}
if claims.OrderType == "" {
claims.OrderType = payment.OrderTypeBalance
}
claims.TokenType = wechatPaymentResumeTokenType
return s.createSignedToken(claims)
}
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims WeChatPaymentResumeClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
}
if claims.TokenType != wechatPaymentResumeTokenType {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
}
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
return nil, err
}
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
if claims.PaymentType == "" {
claims.PaymentType = payment.TypeWxpay
}
if claims.OrderType == "" {
claims.OrderType = payment.OrderTypeBalance
}
return &claims, nil
}
func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return encodedPayload + "." + s.sign(encodedPayload), nil
}
func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
return json.Unmarshal(payload, dest)
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 {
return nil
}
if time.Now().Unix() > expiresAt {
return infraerrors.BadRequest(code, message)
}
return nil
}
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
//go:build unit
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/url"
"strconv"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeVisibleMethods(t *testing.T) {
t.Parallel()
got := NormalizeVisibleMethods([]string{
"alipay_direct",
"alipay",
" wxpay_direct ",
"wxpay",
"stripe",
})
want := []string{"alipay", "wxpay", "stripe"}
if len(got) != len(want) {
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestNormalizePaymentSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expect string
}{
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizePaymentSource(tt.input); got != tt.expect {
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://example.com/payment/result?b=2" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject external hosts")
}
}
func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
}
}
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
parsed, err := url.Parse(got)
if err != nil {
t.Fatalf("url.Parse returned error: %v", err)
}
if parsed.Fragment != "" {
t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment)
}
query := parsed.Query()
if query.Get("from") != "checkout" {
t.Fatalf("expected original query to be preserved, got %q", query.Get("from"))
}
if query.Get("order_id") != strconv.FormatInt(42, 10) {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("resume_token") != "resume-token" {
t.Fatalf("resume_token = %q", query.Get("resume_token"))
}
if query.Get("status") != "success" {
t.Fatalf("status = %q", query.Get("status"))
}
}
func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("", 42, "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
if got != "" {
t.Fatalf("buildPaymentReturnURL = %q, want empty string", got)
}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
if err == nil {
t.Fatal("CreateToken should reject missing signing key")
}
}
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject tokens when signing key is missing")
}
}
func TestParseTokenRejectsExpiredToken(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
_, err = svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject expired tokens")
}
}
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeSubscription,
PlanID: 7,
RedirectTo: "/purchase?from=wechat",
Scope: "snsapi_base",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
t.Fatalf("claims payment context mismatch: %+v", claims)
}
if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
t.Fatalf("claims redirect/scope mismatch: %+v", claims)
}
}
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
if err == nil {
t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
}
}
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
TokenType: wechatPaymentResumeTokenType,
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
}
}
func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
_, err = svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
input string
want string
}{
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
}
})
}
}
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
source string
want string
ok bool
}{
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
if got != tt.want || ok != tt.ok {
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
}
})
}
}
func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create alipay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: newPaymentConfigServiceTestClient(t),
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject when no enabled provider instance exists")
}
}
type captureLoadBalancer struct {
lastProviderKey string
lastPaymentType string
}
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
return map[string]string{}, nil
}
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
c.lastProviderKey = providerKey
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
t.Helper()
payload, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
mac := hmac.New(sha256.New, []byte("sub2api-payment-resume"))
_, _ = mac.Write([]byte(encodedPayload))
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return encodedPayload + "." + signature
}
......@@ -9,7 +9,6 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
......@@ -65,29 +64,39 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
OrderType string
PlanID int64
UserID int64
Amount float64
PaymentType string
OpenID string
ClientIP string
IsMobile bool
IsWeChatBrowser bool
SrcHost string
SrcURL string
ReturnURL string
PaymentSource string
OrderType string
PlanID int64
}
type CreateOrderResponse struct {
OrderID int64 `json:"order_id"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Status string `json:"status"`
PaymentType string `json:"payment_type"`
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
OrderID int64 `json:"order_id"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Status string `json:"status"`
ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"`
PaymentType string `json:"payment_type"`
OutTradeNo string `json:"out_trade_no,omitempty"`
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
......@@ -165,10 +174,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
return svc
}
// --- Provider Registry ---
......@@ -219,25 +231,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) {
}
}
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It extracts out_trade_no from the raw body, looks up the order to find the
// original provider instance, and creates a provider with that instance's credentials.
// Falls back to the registry provider when the order cannot be found.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
p, pErr := s.getOrderProvider(ctx, order)
if pErr == nil {
return p, nil
}
slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
}
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
}
// --- Helpers ---
func psIsRefundStatus(s string) bool {
......@@ -262,6 +255,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
......
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string {
methodSet := make(map[string]struct{}, 2)
addMethod := func(method string) {
method = NormalizeVisibleMethod(method)
switch method {
case payment.TypeAlipay, payment.TypeWxpay:
methodSet[method] = struct{}{}
}
}
switch strings.TrimSpace(providerKey) {
case payment.TypeAlipay:
if strings.TrimSpace(supportedTypes) == "" {
addMethod(payment.TypeAlipay)
break
}
for _, supportedType := range splitTypes(supportedTypes) {
if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay {
addMethod(payment.TypeAlipay)
break
}
}
case payment.TypeWxpay:
if strings.TrimSpace(supportedTypes) == "" {
addMethod(payment.TypeWxpay)
break
}
for _, supportedType := range splitTypes(supportedTypes) {
if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay {
addMethod(payment.TypeWxpay)
break
}
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(supportedTypes) {
addMethod(supportedType)
}
}
methods := make([]string, 0, len(methodSet))
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if _, ok := methodSet[method]; ok {
methods = append(methods, method)
}
}
return methods
}
func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool {
if inst == nil || !inst.Enabled {
return false
}
method = NormalizeVisibleMethod(method)
for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) {
if candidate == method {
return true
}
}
return false
}
func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
filtered = append(filtered, inst)
}
}
return filtered
}
func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error {
metadata := map[string]string{
"payment_method": NormalizeVisibleMethod(method),
}
if conflicting != nil {
metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID)
metadata["conflicting_provider_key"] = conflicting.ProviderKey
metadata["conflicting_provider_name"] = conflicting.Name
}
return infraerrors.Conflict(
"PAYMENT_PROVIDER_CONFLICT",
fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)),
).WithMetadata(metadata)
}
func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
ctx context.Context,
excludeID int64,
providerKey string,
supportedTypes string,
enabled bool,
) error {
if s == nil || s.entClient == nil || !enabled {
return nil
}
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
}
}
return nil
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
ctx context.Context,
method string,
) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil {
return nil, nil
}
method = NormalizeVisibleMethod(method)
if method != payment.TypeAlipay && method != payment.TypeWxpay {
return nil, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
Order(paymentproviderinstance.BySortOrder()).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query enabled payment providers: %w", err)
}
matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) {
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
return nil, buildPaymentProviderConflictError(method, matching[0])
}
}
package service
import (
"context"
"fmt"
"log/slog"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// GetWebhookProvider returns the provider instance that should verify a webhook.
// It resolves the original provider instance from the order whenever possible and
// only falls back to a registry provider for legacy/single-instance scenarios.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
if err != nil {
return nil, err
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers[0], nil
}
// GetWebhookProviders returns provider candidates that can verify the webhook.
// Official WeChat Pay may require multiple candidates because the callback body
// cannot be bound to a merchant before decryption.
func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
if psHasPinnedProviderInstance(order) {
prov, err := s.getPinnedOrderProvider(ctx, order)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
inst, err := s.getOrderProviderInstance(ctx, order)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
prov, err := s.createProviderFromInstance(ctx, inst)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst == nil {
return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
}
return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" || s == nil || s.entClient == nil {
return false
}
count, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Count(ctx)
if err != nil {
slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err)
return false
}
return count <= 1
}
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
providerKey = strings.TrimSpace(providerKey)
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query webhook provider instances: %w", err)
}
if len(instances) == 0 {
return nil, payment.ErrProviderNotFound
}
providers := make([]payment.Provider, 0, len(instances))
for _, inst := range instances {
prov, provErr := s.createProviderFromInstance(ctx, inst)
if provErr != nil {
slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
continue
}
providers = append(providers, prov)
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment