Commit 276ce052 authored by IanShaw027's avatar IanShaw027
Browse files

fix: align payment recovery query refs and resume authority

parent 119f784d
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"strconv" "strconv"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -139,20 +140,18 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s ...@@ -139,20 +140,18 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil { if err != nil {
return "" return ""
} }
// Use OutTradeNo as fallback when PaymentTradeNo is empty queryRef := paymentOrderQueryReference(o, prov)
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback) if queryRef == "" {
tradeNo := o.PaymentTradeNo return ""
if tradeNo == "" {
tradeNo = o.OutTradeNo
} }
resp, err := prov.QueryOrder(ctx, tradeNo) resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil { if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err) slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return "" return ""
} }
if resp.Status == payment.ProviderStatusPaid { if resp.Status == payment.ProviderStatusPaid {
notificationTradeNo := o.PaymentTradeNo notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := resp.TradeNo; upstreamTradeNo != "" && upstreamTradeNo != notificationTradeNo { if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update(). if _, updateErr := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(o.ID)). Where(paymentorder.IDEQ(o.ID)).
SetPaymentTradeNo(upstreamTradeNo). SetPaymentTradeNo(upstreamTradeNo).
...@@ -170,11 +169,57 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s ...@@ -170,11 +169,57 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return checkPaidResultAlreadyPaid return checkPaidResultAlreadyPaid
} }
if cp, ok := prov.(payment.CancelableProvider); ok { if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo) _ = cp.CancelPayment(ctx, queryRef)
} }
return "" return ""
} }
func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
if order == nil {
return ""
}
providerKey := ""
if prov != nil {
providerKey = strings.TrimSpace(prov.ProviderKey())
}
if providerKey == "" {
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
providerKey = strings.TrimSpace(snapshot.ProviderKey)
}
}
if providerKey == "" {
providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
}
if providerKey == "" {
providerKey = strings.TrimSpace(order.PaymentType)
}
switch payment.GetBasePaymentType(providerKey) {
case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
return strings.TrimSpace(order.OutTradeNo)
default:
if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
return tradeNo
}
return strings.TrimSpace(order.OutTradeNo)
}
}
func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
if upstreamTradeNo == "" {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
return false
}
if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
return false
}
return true
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check // VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where // if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode). // the provider's notify callback was missed (e.g. EasyPay popup mode).
......
...@@ -234,6 +234,95 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { ...@@ -234,6 +234,95 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
} }
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-existing-trade@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-existing-trade-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("upstream-trade-existing").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-existing",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
}
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client { func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper() t.Helper()
......
...@@ -21,10 +21,21 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token ...@@ -21,10 +21,21 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
if claims.UserID > 0 && order.UserID != claims.UserID { if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch") return nil, fmt.Errorf("resume token user mismatch")
} }
if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID { snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
if snapshot != nil {
if snapshot.ProviderInstanceID != "" {
orderProviderInstanceID = snapshot.ProviderInstanceID
}
if snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch") return nil, fmt.Errorf("resume token provider instance mismatch")
} }
if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey { if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch") return nil, fmt.Errorf("resume token provider key mismatch")
} }
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
......
...@@ -146,6 +146,63 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { ...@@ -146,6 +146,63 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
require.Contains(t, err.Error(), "resume token") require.Contains(t, err.Error(), "resume token")
} }
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("resume-snapshot-authority@example.com").
SetPasswordHash("hash").
SetUsername("resume-snapshot-authority-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
SetOutTradeNo("sub2_resume_snapshot_authority").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-snapshot-authority").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID("legacy-column-instance").
SetProviderKey(payment.TypeAlipay).
SetProviderSnapshot(map[string]any{
"schema_version": 2,
"provider_instance_id": "snapshot-instance",
"provider_key": payment.TypeEasyPay,
}).
Save(ctx)
require.NoError(t, err)
resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
ProviderInstanceID: "snapshot-instance",
ProviderKey: payment.TypeEasyPay,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
resumeService: resumeSvc,
}
got, err := svc.GetPublicOrderByResumeToken(ctx, token)
require.NoError(t, err)
require.Equal(t, order.ID, got.ID)
}
func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) { func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client := newPaymentConfigServiceTestClient(t) client := newPaymentConfigServiceTestClient(t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment