Unverified Commit 97f14b7a authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1572 from touwaeriol/feat/payment-system-v2

feat(payment): add complete payment system with multi-provider support
parents 1ef3782d 6793503e
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteSuccessResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
providerKey string
wantCode int
wantContentType string
wantBody string
checkJSON bool
wantJSONCode string
wantJSONMessage string
}{
{
name: "wxpay returns JSON with code SUCCESS",
providerKey: "wxpay",
wantCode: http.StatusOK,
wantContentType: "application/json",
checkJSON: true,
wantJSONCode: "SUCCESS",
wantJSONMessage: "成功",
},
{
name: "stripe returns empty 200",
providerKey: "stripe",
wantCode: http.StatusOK,
wantContentType: "text/plain",
wantBody: "",
},
{
name: "easypay returns plain text success",
providerKey: "easypay",
wantCode: http.StatusOK,
wantContentType: "text/plain",
wantBody: "success",
},
{
name: "alipay returns plain text success",
providerKey: "alipay",
wantCode: http.StatusOK,
wantContentType: "text/plain",
wantBody: "success",
},
{
name: "unknown provider returns plain text success",
providerKey: "unknown_provider",
wantCode: http.StatusOK,
wantContentType: "text/plain",
wantBody: "success",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
writeSuccessResponse(c, tt.providerKey)
assert.Equal(t, tt.wantCode, w.Code)
assert.Contains(t, w.Header().Get("Content-Type"), tt.wantContentType)
if tt.checkJSON {
var resp wxpaySuccessResponse
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err, "response body should be valid JSON")
assert.Equal(t, tt.wantJSONCode, resp.Code)
assert.Equal(t, tt.wantJSONMessage, resp.Message)
} else {
assert.Equal(t, tt.wantBody, w.Body.String())
}
})
}
}
func TestWebhookConstants(t *testing.T) {
t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) {
assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize))
})
t.Run("webhookLogTruncateLen is 200", func(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
...@@ -59,6 +59,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -59,6 +59,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName, OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: h.version, Version: h.version,
}) })
} }
...@@ -34,6 +34,7 @@ func ProvideAdminHandlers( ...@@ -34,6 +34,7 @@ func ProvideAdminHandlers(
apiKeyHandler *admin.AdminAPIKeyHandler, apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler, scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler, channelHandler *admin.ChannelHandler,
paymentHandler *admin.PaymentHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
...@@ -61,6 +62,7 @@ func ProvideAdminHandlers( ...@@ -61,6 +62,7 @@ func ProvideAdminHandlers(
APIKey: apiKeyHandler, APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler, ScheduledTest: scheduledTestHandler,
Channel: channelHandler, Channel: channelHandler,
Payment: paymentHandler,
} }
} }
...@@ -88,22 +90,26 @@ func ProvideHandlers( ...@@ -88,22 +90,26 @@ func ProvideHandlers(
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler, totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
paymentWebhookHandler *PaymentWebhookHandler,
_ *service.IdempotencyCoordinator, _ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService, _ *service.IdempotencyCleanupService,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
User: userHandler, User: userHandler,
APIKey: apiKeyHandler, APIKey: apiKeyHandler,
Usage: usageHandler, Usage: usageHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Announcement: announcementHandler, Announcement: announcementHandler,
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler, Totp: totpHandler,
Payment: paymentHandler,
PaymentWebhook: paymentWebhookHandler,
} }
} }
...@@ -121,6 +127,8 @@ var ProviderSet = wire.NewSet( ...@@ -121,6 +127,8 @@ var ProviderSet = wire.NewSet(
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewTotpHandler, NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
NewPaymentHandler,
NewPaymentWebhookHandler,
// Admin handlers // Admin handlers
admin.NewDashboardHandler, admin.NewDashboardHandler,
...@@ -148,6 +156,7 @@ var ProviderSet = wire.NewSet( ...@@ -148,6 +156,7 @@ var ProviderSet = wire.NewSet(
admin.NewAdminAPIKeyHandler, admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler, admin.NewScheduledTestHandler,
admin.NewChannelHandler, admin.NewChannelHandler,
admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,
......
package payment
import (
"fmt"
"github.com/shopspring/decimal"
)
const centsPerYuan = 100
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
// Uses shopspring/decimal for precision.
func YuanToFen(yuanStr string) (int64, error) {
d, err := decimal.NewFromString(yuanStr)
if err != nil {
return 0, fmt.Errorf("invalid amount: %s", yuanStr)
}
return d.Mul(decimal.NewFromInt(centsPerYuan)).IntPart(), nil
}
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
func FenToYuan(fen int64) float64 {
return decimal.NewFromInt(fen).Div(decimal.NewFromInt(centsPerYuan)).InexactFloat64()
}
//go:build unit
package payment
import (
"math"
"testing"
)
func TestYuanToFen(t *testing.T) {
tests := []struct {
name string
input string
want int64
wantErr bool
}{
// Normal values
{name: "one yuan", input: "1.00", want: 100},
{name: "ten yuan fifty fen", input: "10.50", want: 1050},
{name: "one fen", input: "0.01", want: 1},
{name: "large amount", input: "99999.99", want: 9999999},
// Edge: zero
{name: "zero no decimal", input: "0", want: 0},
{name: "zero with decimal", input: "0.00", want: 0},
// IEEE 754 precision edge case: 1.15 * 100 = 114.99999... in float64
{name: "ieee754 precision 1.15", input: "1.15", want: 115},
// More precision edge cases
{name: "ieee754 precision 0.1", input: "0.1", want: 10},
{name: "ieee754 precision 0.2", input: "0.2", want: 20},
{name: "ieee754 precision 33.33", input: "33.33", want: 3333},
// Large value
{name: "hundred thousand", input: "100000.00", want: 10000000},
// Integer without decimal
{name: "integer 5", input: "5", want: 500},
{name: "integer 100", input: "100", want: 10000},
// Single decimal place
{name: "single decimal 1.5", input: "1.5", want: 150},
// Negative values
{name: "negative one yuan", input: "-1.00", want: -100},
{name: "negative with fen", input: "-10.50", want: -1050},
// Invalid inputs
{name: "empty string", input: "", wantErr: true},
{name: "alphabetic", input: "abc", wantErr: true},
{name: "double dot", input: "1.2.3", wantErr: true},
{name: "spaces", input: " ", wantErr: true},
{name: "special chars", input: "$10.00", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := YuanToFen(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("YuanToFen(%q) expected error, got %d", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("YuanToFen(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("YuanToFen(%q) = %d, want %d", tt.input, got, tt.want)
}
})
}
}
func TestFenToYuan(t *testing.T) {
tests := []struct {
name string
fen int64
want float64
}{
{name: "one yuan", fen: 100, want: 1.0},
{name: "ten yuan fifty fen", fen: 1050, want: 10.5},
{name: "one fen", fen: 1, want: 0.01},
{name: "zero", fen: 0, want: 0.0},
{name: "large amount", fen: 9999999, want: 99999.99},
{name: "negative", fen: -100, want: -1.0},
{name: "negative with fen", fen: -1050, want: -10.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FenToYuan(tt.fen)
if math.Abs(got-tt.want) > 1e-9 {
t.Errorf("FenToYuan(%d) = %f, want %f", tt.fen, got, tt.want)
}
})
}
}
func TestYuanToFenRoundTrip(t *testing.T) {
// Verify that converting yuan->fen->yuan preserves the value.
cases := []struct {
yuan string
fen int64
}{
{"0.01", 1},
{"1.00", 100},
{"10.50", 1050},
{"99999.99", 9999999},
}
for _, tc := range cases {
fen, err := YuanToFen(tc.yuan)
if err != nil {
t.Fatalf("YuanToFen(%q) unexpected error: %v", tc.yuan, err)
}
if fen != tc.fen {
t.Errorf("YuanToFen(%q) = %d, want %d", tc.yuan, fen, tc.fen)
}
yuan := FenToYuan(fen)
// Parse expected yuan back for comparison
expectedYuan := FenToYuan(tc.fen)
if math.Abs(yuan-expectedYuan) > 1e-9 {
t.Errorf("round-trip: FenToYuan(%d) = %f, want %f", fen, yuan, expectedYuan)
}
}
}
package payment
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"strings"
)
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
func Encrypt(plaintext string, key []byte) (string, error) {
if len(key) != 32 {
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create GCM: %w", err)
}
nonce := make([]byte, gcm.NonceSize()) // 12 bytes for GCM
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Seal appends the ciphertext + auth tag
sealed := gcm.Seal(nil, nonce, []byte(plaintext), nil)
// Split sealed into ciphertext and auth tag (last 16 bytes)
tagSize := gcm.Overhead()
ciphertext := sealed[:len(sealed)-tagSize]
authTag := sealed[len(sealed)-tagSize:]
// Format: iv:authTag:ciphertext (all base64)
return fmt.Sprintf("%s:%s:%s",
base64.StdEncoding.EncodeToString(nonce),
base64.StdEncoding.EncodeToString(authTag),
base64.StdEncoding.EncodeToString(ciphertext),
), nil
}
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
func Decrypt(ciphertext string, key []byte) (string, error) {
if len(key) != 32 {
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
}
parts := strings.SplitN(ciphertext, ":", 3)
if len(parts) != 3 {
return "", fmt.Errorf("invalid ciphertext format: expected iv:authTag:ciphertext")
}
nonce, err := base64.StdEncoding.DecodeString(parts[0])
if err != nil {
return "", fmt.Errorf("decode IV: %w", err)
}
authTag, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("decode auth tag: %w", err)
}
encrypted, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return "", fmt.Errorf("decode ciphertext: %w", err)
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create GCM: %w", err)
}
// Reconstruct the sealed data: ciphertext + authTag
sealed := append(encrypted, authTag...)
plaintext, err := gcm.Open(nil, nonce, sealed, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}
package payment
import (
"crypto/rand"
"strings"
"testing"
)
func makeKey(t *testing.T) []byte {
t.Helper()
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
t.Fatalf("generate random key: %v", err)
}
return key
}
func TestEncryptDecryptRoundTrip(t *testing.T) {
t.Parallel()
key := makeKey(t)
plaintexts := []string{
"hello world",
"short",
"a longer string with special chars: !@#$%^&*()",
`{"key":"value","num":42}`,
"你好世界 unicode test 🎉",
strings.Repeat("x", 10000),
}
for _, pt := range plaintexts {
encrypted, err := Encrypt(pt, key)
if err != nil {
t.Fatalf("Encrypt(%q) error: %v", pt[:min(len(pt), 30)], err)
}
decrypted, err := Decrypt(encrypted, key)
if err != nil {
t.Fatalf("Decrypt error for plaintext %q: %v", pt[:min(len(pt), 30)], err)
}
if decrypted != pt {
t.Fatalf("round-trip failed: got %q, want %q", decrypted[:min(len(decrypted), 30)], pt[:min(len(pt), 30)])
}
}
}
func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
t.Parallel()
key := makeKey(t)
ct1, err := Encrypt("same plaintext", key)
if err != nil {
t.Fatalf("first Encrypt error: %v", err)
}
ct2, err := Encrypt("same plaintext", key)
if err != nil {
t.Fatalf("second Encrypt error: %v", err)
}
if ct1 == ct2 {
t.Fatal("two encryptions of the same plaintext should produce different ciphertexts (random nonce)")
}
}
func TestDecryptWithWrongKeyFails(t *testing.T) {
t.Parallel()
key1 := makeKey(t)
key2 := makeKey(t)
encrypted, err := Encrypt("secret data", key1)
if err != nil {
t.Fatalf("Encrypt error: %v", err)
}
_, err = Decrypt(encrypted, key2)
if err == nil {
t.Fatal("Decrypt with wrong key should fail, but got nil error")
}
}
func TestEncryptRejectsInvalidKeyLength(t *testing.T) {
t.Parallel()
badKeys := [][]byte{
nil,
make([]byte, 0),
make([]byte, 16),
make([]byte, 31),
make([]byte, 33),
make([]byte, 64),
}
for _, key := range badKeys {
_, err := Encrypt("test", key)
if err == nil {
t.Fatalf("Encrypt should reject key of length %d", len(key))
}
}
}
func TestDecryptRejectsInvalidKeyLength(t *testing.T) {
t.Parallel()
badKeys := [][]byte{
nil,
make([]byte, 16),
make([]byte, 33),
}
for _, key := range badKeys {
_, err := Decrypt("dummydata:dummydata:dummydata", key)
if err == nil {
t.Fatalf("Decrypt should reject key of length %d", len(key))
}
}
}
func TestEncryptEmptyPlaintext(t *testing.T) {
t.Parallel()
key := makeKey(t)
encrypted, err := Encrypt("", key)
if err != nil {
t.Fatalf("Encrypt empty plaintext error: %v", err)
}
decrypted, err := Decrypt(encrypted, key)
if err != nil {
t.Fatalf("Decrypt empty plaintext error: %v", err)
}
if decrypted != "" {
t.Fatalf("expected empty string, got %q", decrypted)
}
}
func TestEncryptDecryptUnicodeJSON(t *testing.T) {
t.Parallel()
key := makeKey(t)
jsonContent := `{"name":"测试用户","email":"test@example.com","balance":100.50}`
encrypted, err := Encrypt(jsonContent, key)
if err != nil {
t.Fatalf("Encrypt JSON error: %v", err)
}
decrypted, err := Decrypt(encrypted, key)
if err != nil {
t.Fatalf("Decrypt JSON error: %v", err)
}
if decrypted != jsonContent {
t.Fatalf("JSON round-trip failed: got %q, want %q", decrypted, jsonContent)
}
}
func TestDecryptInvalidFormat(t *testing.T) {
t.Parallel()
key := makeKey(t)
invalidInputs := []string{
"",
"nodelimiter",
"only:two",
"invalid:base64:!!!",
}
for _, input := range invalidInputs {
_, err := Decrypt(input, key)
if err == nil {
t.Fatalf("Decrypt(%q) should fail but got nil error", input)
}
}
}
func TestCiphertextFormat(t *testing.T) {
t.Parallel()
key := makeKey(t)
encrypted, err := Encrypt("test", key)
if err != nil {
t.Fatalf("Encrypt error: %v", err)
}
parts := strings.SplitN(encrypted, ":", 3)
if len(parts) != 3 {
t.Fatalf("ciphertext should have format iv:authTag:ciphertext, got %d parts", len(parts))
}
for i, part := range parts {
if part == "" {
t.Fatalf("ciphertext part %d is empty", i)
}
}
}
package payment
import (
"github.com/shopspring/decimal"
)
// CalculatePayAmount computes the total pay amount given a recharge amount and
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
amount := decimal.NewFromFloat(rechargeAmount)
if feeRate <= 0 {
return amount.StringFixed(2)
}
rate := decimal.NewFromFloat(feeRate)
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(2)
return amount.Add(fee).StringFixed(2)
}
package payment
import (
"testing"
)
func TestCalculatePayAmount(t *testing.T) {
t.Parallel()
tests := []struct {
name string
amount float64
feeRate float64
expected string
}{
{
name: "zero fee rate returns same amount",
amount: 100.00,
feeRate: 0,
expected: "100.00",
},
{
name: "negative fee rate returns same amount",
amount: 50.00,
feeRate: -5,
expected: "50.00",
},
{
name: "1 percent fee rate",
amount: 100.00,
feeRate: 1,
expected: "101.00",
},
{
name: "5 percent fee on 200",
amount: 200.00,
feeRate: 5,
expected: "210.00",
},
{
name: "fee rounds UP to 2 decimal places",
amount: 100.00,
feeRate: 3,
expected: "103.00",
},
{
name: "fee rounds UP small remainder",
amount: 10.00,
feeRate: 3.33,
expected: "10.34", // 10 * 3.33 / 100 = 0.333 -> round up -> 0.34
},
{
name: "very small amount",
amount: 0.01,
feeRate: 1,
expected: "0.02", // 0.01 * 1/100 = 0.0001 -> round up -> 0.01 -> total 0.02
},
{
name: "large amount",
amount: 99999.99,
feeRate: 10,
expected: "109999.99", // 99999.99 * 10/100 = 9999.999 -> round up -> 10000.00 -> total 109999.99
},
{
name: "100 percent fee rate doubles amount",
amount: 50.00,
feeRate: 100,
expected: "100.00",
},
{
name: "precision 0.01 fee difference",
amount: 100.00,
feeRate: 1.01,
expected: "101.01", // 100 * 1.01/100 = 1.01
},
{
name: "precision 0.02 fee",
amount: 100.00,
feeRate: 1.02,
expected: "101.02",
},
{
name: "zero amount with positive fee",
amount: 0,
feeRate: 5,
expected: "0.00",
},
{
name: "fractional amount no fee",
amount: 19.99,
feeRate: 0,
expected: "19.99",
},
{
name: "fractional fee that causes rounding up",
amount: 33.33,
feeRate: 7.77,
expected: "35.92", // 33.33 * 7.77 / 100 = 2.589741 -> round up -> 2.59 -> total 35.92
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := CalculatePayAmount(tt.amount, tt.feeRate)
if got != tt.expected {
t.Fatalf("CalculatePayAmount(%v, %v) = %q, want %q", tt.amount, tt.feeRate, got, tt.expected)
}
})
}
}
package payment
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
)
// Strategy represents a load balancing strategy for provider instance selection.
type Strategy string
const (
StrategyRoundRobin Strategy = "round-robin"
StrategyLeastAmount Strategy = "least-amount"
)
// ChannelLimits holds limits for a single payment channel within a provider instance.
type ChannelLimits struct {
DailyLimit float64 `json:"dailyLimit,omitempty"`
SingleMin float64 `json:"singleMin,omitempty"`
SingleMax float64 `json:"singleMax,omitempty"`
}
// InstanceLimits holds per-channel limits for a provider instance (JSON).
type InstanceLimits map[string]ChannelLimits
// LoadBalancer selects a provider instance for a given payment type.
type LoadBalancer interface {
GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error)
SelectInstance(ctx context.Context, providerKey string, paymentType PaymentType, strategy Strategy, orderAmount float64) (*InstanceSelection, error)
}
// DefaultLoadBalancer implements LoadBalancer using database queries.
type DefaultLoadBalancer struct {
db *dbent.Client
encryptionKey []byte
counter atomic.Uint64
}
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
dailyUsed float64 // includes PENDING orders
}
// SelectInstance picks an enabled instance for the given provider key and payment type.
//
// Flow:
// 1. Query all enabled instances for providerKey, filter by supported paymentType
// 2. Batch-query daily usage (PENDING + PAID + COMPLETED + RECHARGING) for all candidates
// 3. Filter out instances where: single-min/max violated OR daily remaining < orderAmount
// 4. Pick from survivors using the configured strategy (round-robin / least-amount)
// 5. If all filtered out, fall back to full list (let the provider itself reject)
func (lb *DefaultLoadBalancer) SelectInstance(
ctx context.Context,
providerKey string,
paymentType PaymentType,
strategy Strategy,
orderAmount float64,
) (*InstanceSelection, error) {
// Step 1: query enabled instances matching payment type.
instances, err := lb.queryEnabledInstances(ctx, providerKey, paymentType)
if err != nil {
return nil, err
}
// Step 2: batch-fetch daily usage for all candidates.
candidates := lb.attachDailyUsage(ctx, instances)
// Step 3: filter by limits.
available := filterByLimits(candidates, paymentType, orderAmount)
if len(available) == 0 {
slog.Warn("all instances exceeded limits, using full candidate list",
"provider", providerKey, "payment_type", paymentType,
"order_amount", orderAmount, "count", len(candidates))
available = candidates
}
// Step 4: pick by strategy.
selected := lb.pickByStrategy(available, strategy)
return lb.buildSelection(selected.inst)
}
// queryEnabledInstances returns enabled instances for providerKey that support paymentType.
func (lb *DefaultLoadBalancer) queryEnabledInstances(
ctx context.Context,
providerKey string,
paymentType PaymentType,
) ([]*dbent.PaymentProviderInstance, error) {
instances, err := lb.db.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKey(providerKey),
paymentproviderinstance.Enabled(true),
).
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query provider instances: %w", err)
}
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
if paymentType == providerKey || InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst)
}
}
if len(matched) == 0 {
return nil, fmt.Errorf("no enabled instance for provider %s type %s", providerKey, paymentType)
}
return matched, nil
}
// attachDailyUsage queries daily usage for each instance in a single pass.
// Usage includes PENDING orders to avoid over-committing capacity.
func (lb *DefaultLoadBalancer) attachDailyUsage(
ctx context.Context,
instances []*dbent.PaymentProviderInstance,
) []instanceCandidate {
todayStart := startOfDay(time.Now())
// Collect instance IDs.
ids := make([]string, len(instances))
for i, inst := range instances {
ids[i] = fmt.Sprintf("%d", inst.ID)
}
// Batch query: sum pay_amount grouped by provider_instance_id.
type row struct {
InstanceID string `json:"provider_instance_id"`
Sum float64 `json:"sum"`
}
var rows []row
err := lb.db.PaymentOrder.Query().
Where(
paymentorder.ProviderInstanceIDIn(ids...),
paymentorder.StatusIn(
OrderStatusPending, OrderStatusPaid,
OrderStatusCompleted, OrderStatusRecharging,
),
paymentorder.CreatedAtGTE(todayStart),
).
GroupBy(paymentorder.FieldProviderInstanceID).
Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
Scan(ctx, &rows)
if err != nil {
slog.Warn("batch daily usage query failed, treating all as zero", "error", err)
}
usageMap := make(map[string]float64, len(rows))
for _, r := range rows {
usageMap[r.InstanceID] = r.Sum
}
candidates := make([]instanceCandidate, len(instances))
for i, inst := range instances {
candidates[i] = instanceCandidate{
inst: inst,
dailyUsed: usageMap[fmt.Sprintf("%d", inst.ID)],
}
}
return candidates
}
// filterByLimits removes instances that cannot accommodate the order:
// - orderAmount outside single-transaction [min, max]
// - daily remaining capacity (limit - used) < orderAmount
func filterByLimits(candidates []instanceCandidate, paymentType PaymentType, orderAmount float64) []instanceCandidate {
var result []instanceCandidate
for _, c := range candidates {
cl := getInstanceChannelLimits(c.inst, paymentType)
if cl.SingleMin > 0 && orderAmount < cl.SingleMin {
slog.Info("order below instance single min, skipping",
"instance_id", c.inst.ID, "order", orderAmount, "min", cl.SingleMin)
continue
}
if cl.SingleMax > 0 && orderAmount > cl.SingleMax {
slog.Info("order above instance single max, skipping",
"instance_id", c.inst.ID, "order", orderAmount, "max", cl.SingleMax)
continue
}
if cl.DailyLimit > 0 && c.dailyUsed+orderAmount > cl.DailyLimit {
slog.Info("instance daily remaining insufficient, skipping",
"instance_id", c.inst.ID, "used", c.dailyUsed,
"order", orderAmount, "limit", cl.DailyLimit)
continue
}
result = append(result, c)
}
return result
}
// getInstanceChannelLimits returns the channel limits for a specific payment type.
func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType PaymentType) ChannelLimits {
if inst.Limits == "" {
return ChannelLimits{}
}
var limits InstanceLimits
if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
return ChannelLimits{}
}
// For Stripe, limits are stored under the provider key "stripe".
lookupKey := paymentType
if inst.ProviderKey == "stripe" {
lookupKey = "stripe"
}
if cl, ok := limits[lookupKey]; ok {
return cl
}
return ChannelLimits{}
}
// pickByStrategy selects one instance from the available candidates.
func (lb *DefaultLoadBalancer) pickByStrategy(candidates []instanceCandidate, strategy Strategy) instanceCandidate {
if strategy == StrategyLeastAmount && len(candidates) > 1 {
return pickLeastAmount(candidates)
}
// Default: round-robin.
idx := lb.counter.Add(1) % uint64(len(candidates))
return candidates[idx]
}
// pickLeastAmount selects the instance with the lowest daily usage.
// No extra DB queries — usage was pre-fetched in attachDailyUsage.
func pickLeastAmount(candidates []instanceCandidate) instanceCandidate {
best := candidates[0]
for _, c := range candidates[1:] {
if c.dailyUsed < best.dailyUsed {
best = c
}
}
return best
}
func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderInstance) (*InstanceSelection, error) {
config, err := lb.decryptConfig(selected.Config)
if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
}
if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode
}
return &InstanceSelection{
InstanceID: fmt.Sprintf("%d", selected.ID),
Config: config,
SupportedTypes: selected.SupportedTypes,
PaymentMode: selected.PaymentMode,
}, nil
}
func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
plaintext, err := Decrypt(encrypted, lb.encryptionKey)
if err != nil {
return nil, err
}
var config map[string]string
if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
return nil, fmt.Errorf("unmarshal config: %w", err)
}
return config, nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
func (lb *DefaultLoadBalancer) GetInstanceDailyAmount(ctx context.Context, instanceID string) (float64, error) {
todayStart := startOfDay(time.Now())
var result []struct {
Sum float64 `json:"sum"`
}
err := lb.db.PaymentOrder.Query().
Where(
paymentorder.ProviderInstanceID(instanceID),
paymentorder.StatusIn(OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging),
paymentorder.PaidAtGTE(todayStart),
).
Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
Scan(ctx, &result)
if err != nil {
return 0, fmt.Errorf("query daily amount: %w", err)
}
if len(result) > 0 {
return result[0].Sum, nil
}
return 0, nil
}
func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
}
// InstanceSupportsType checks if the given supported types string includes the target type.
// An empty supportedTypes string means all types are supported.
func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
if supportedTypes == "" {
return true
}
for _, t := range strings.Split(supportedTypes, ",") {
if strings.TrimSpace(t) == target {
return true
}
}
return false
}
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
if err != nil {
return nil, fmt.Errorf("get instance %d: %w", instanceID, err)
}
return lb.decryptConfig(inst.Config)
}
//go:build unit
package payment
import (
"encoding/json"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
)
func TestInstanceSupportsType(t *testing.T) {
t.Parallel()
tests := []struct {
name string
supportedTypes string
target PaymentType
expected bool
}{
{
name: "exact match single type",
supportedTypes: "alipay",
target: "alipay",
expected: true,
},
{
name: "no match single type",
supportedTypes: "wxpay",
target: "alipay",
expected: false,
},
{
name: "match in comma-separated list",
supportedTypes: "alipay,wxpay,stripe",
target: "wxpay",
expected: true,
},
{
name: "first in comma-separated list",
supportedTypes: "alipay,wxpay",
target: "alipay",
expected: true,
},
{
name: "last in comma-separated list",
supportedTypes: "alipay,wxpay,stripe",
target: "stripe",
expected: true,
},
{
name: "no match in comma-separated list",
supportedTypes: "alipay,wxpay",
target: "stripe",
expected: false,
},
{
name: "empty target",
supportedTypes: "alipay,wxpay",
target: "",
expected: false,
},
{
name: "types with spaces are trimmed",
supportedTypes: " alipay , wxpay ",
target: "alipay",
expected: true,
},
{
name: "partial match should not succeed",
supportedTypes: "alipay_direct",
target: "alipay",
expected: false,
},
{
name: "empty supported types means all supported",
supportedTypes: "",
target: "alipay",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := InstanceSupportsType(tt.supportedTypes, tt.target)
if got != tt.expected {
t.Fatalf("InstanceSupportsType(%q, %q) = %v, want %v", tt.supportedTypes, tt.target, got, tt.expected)
}
})
}
}
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
func testInstance(id int64, providerKey, limits string) *dbent.PaymentProviderInstance {
return &dbent.PaymentProviderInstance{
ID: id,
ProviderKey: providerKey,
Limits: limits,
Enabled: true,
}
}
// makeLimitsJSON builds a limits JSON string for a single payment type.
func makeLimitsJSON(paymentType string, cl ChannelLimits) string {
m := map[string]ChannelLimits{paymentType: cl}
b, _ := json.Marshal(m)
return string(b)
}
// ---------------------------------------------------------------------------
// filterByLimits
// ---------------------------------------------------------------------------
func TestFilterByLimits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
candidates []instanceCandidate
paymentType PaymentType
orderAmount float64
wantIDs []int64 // expected surviving instance IDs
}{
{
name: "order below SingleMin is filtered out",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
},
paymentType: "alipay",
orderAmount: 5,
wantIDs: nil,
},
{
name: "order at exact SingleMin boundary passes",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
},
paymentType: "alipay",
orderAmount: 10,
wantIDs: []int64{1},
},
{
name: "order above SingleMax is filtered out",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
},
paymentType: "alipay",
orderAmount: 150,
wantIDs: nil,
},
{
name: "order at exact SingleMax boundary passes",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
},
paymentType: "alipay",
orderAmount: 100,
wantIDs: []int64{1},
},
{
name: "daily used + orderAmount exceeding dailyLimit is filtered out",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
},
paymentType: "alipay",
orderAmount: 30,
wantIDs: nil, // 480+30=510 > 500
},
{
name: "daily used + orderAmount equal to dailyLimit passes (strict greater-than)",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
},
paymentType: "alipay",
orderAmount: 20,
wantIDs: []int64{1}, // 480+20=500, 500 > 500 is false → passes
},
{
name: "daily used + orderAmount below dailyLimit passes",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 400},
},
paymentType: "alipay",
orderAmount: 50,
wantIDs: []int64{1},
},
{
name: "no limits configured passes through",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", ""), dailyUsed: 99999},
},
paymentType: "alipay",
orderAmount: 100,
wantIDs: []int64{1},
},
{
name: "multiple candidates with partial filtering",
candidates: []instanceCandidate{
// singleMax=50, order=80 → filtered out
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 50})), dailyUsed: 0},
// no limits → passes
{inst: testInstance(2, "easypay", ""), dailyUsed: 0},
// singleMin=100, order=80 → filtered out
{inst: testInstance(3, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 100})), dailyUsed: 0},
// daily limit ok → passes (500+80=580 < 1000)
{inst: testInstance(4, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 1000})), dailyUsed: 500},
},
paymentType: "alipay",
orderAmount: 80,
wantIDs: []int64{2, 4},
},
{
name: "zero SingleMin and SingleMax means no single-transaction limit",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 0, SingleMax: 0, DailyLimit: 0})), dailyUsed: 0},
},
paymentType: "alipay",
orderAmount: 99999,
wantIDs: []int64{1},
},
{
name: "all limits combined - order passes all checks",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
},
paymentType: "alipay",
orderAmount: 50,
wantIDs: []int64{1},
},
{
name: "all limits combined - order fails SingleMin",
candidates: []instanceCandidate{
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
},
paymentType: "alipay",
orderAmount: 5,
wantIDs: nil,
},
{
name: "empty candidates returns empty",
candidates: nil,
paymentType: "alipay",
orderAmount: 10,
wantIDs: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filterByLimits(tt.candidates, tt.paymentType, tt.orderAmount)
gotIDs := make([]int64, len(got))
for i, c := range got {
gotIDs[i] = c.inst.ID
}
if !int64SliceEqual(gotIDs, tt.wantIDs) {
t.Fatalf("filterByLimits() returned IDs %v, want %v", gotIDs, tt.wantIDs)
}
})
}
}
// ---------------------------------------------------------------------------
// pickLeastAmount
// ---------------------------------------------------------------------------
func TestPickLeastAmount(t *testing.T) {
t.Parallel()
t.Run("picks candidate with lowest dailyUsed", func(t *testing.T) {
t.Parallel()
candidates := []instanceCandidate{
{inst: testInstance(1, "easypay", ""), dailyUsed: 300},
{inst: testInstance(2, "easypay", ""), dailyUsed: 100},
{inst: testInstance(3, "easypay", ""), dailyUsed: 200},
}
got := pickLeastAmount(candidates)
if got.inst.ID != 2 {
t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
}
})
t.Run("with equal dailyUsed picks the first one", func(t *testing.T) {
t.Parallel()
candidates := []instanceCandidate{
{inst: testInstance(1, "easypay", ""), dailyUsed: 100},
{inst: testInstance(2, "easypay", ""), dailyUsed: 100},
{inst: testInstance(3, "easypay", ""), dailyUsed: 200},
}
got := pickLeastAmount(candidates)
if got.inst.ID != 1 {
t.Fatalf("pickLeastAmount() picked instance %d, want 1 (first with lowest)", got.inst.ID)
}
})
t.Run("single candidate returns that candidate", func(t *testing.T) {
t.Parallel()
candidates := []instanceCandidate{
{inst: testInstance(42, "easypay", ""), dailyUsed: 999},
}
got := pickLeastAmount(candidates)
if got.inst.ID != 42 {
t.Fatalf("pickLeastAmount() picked instance %d, want 42", got.inst.ID)
}
})
t.Run("zero usage among non-zero picks zero", func(t *testing.T) {
t.Parallel()
candidates := []instanceCandidate{
{inst: testInstance(1, "easypay", ""), dailyUsed: 500},
{inst: testInstance(2, "easypay", ""), dailyUsed: 0},
{inst: testInstance(3, "easypay", ""), dailyUsed: 300},
}
got := pickLeastAmount(candidates)
if got.inst.ID != 2 {
t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
}
})
}
// ---------------------------------------------------------------------------
// getInstanceChannelLimits
// ---------------------------------------------------------------------------
func TestGetInstanceChannelLimits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
inst *dbent.PaymentProviderInstance
paymentType PaymentType
want ChannelLimits
}{
{
name: "empty limits string returns zero ChannelLimits",
inst: testInstance(1, "easypay", ""),
paymentType: "alipay",
want: ChannelLimits{},
},
{
name: "invalid JSON returns zero ChannelLimits",
inst: testInstance(1, "easypay", "not-json{"),
paymentType: "alipay",
want: ChannelLimits{},
},
{
name: "valid JSON with matching payment type",
inst: testInstance(1, "easypay",
`{"alipay":{"singleMin":5,"singleMax":200,"dailyLimit":1000}}`),
paymentType: "alipay",
want: ChannelLimits{SingleMin: 5, SingleMax: 200, DailyLimit: 1000},
},
{
name: "payment type not in limits returns zero ChannelLimits",
inst: testInstance(1, "easypay",
`{"alipay":{"singleMin":5,"singleMax":200}}`),
paymentType: "wxpay",
want: ChannelLimits{},
},
{
name: "stripe provider uses stripe lookup key regardless of payment type",
inst: testInstance(1, "stripe",
`{"stripe":{"singleMin":10,"singleMax":500,"dailyLimit":5000}}`),
paymentType: "alipay",
want: ChannelLimits{SingleMin: 10, SingleMax: 500, DailyLimit: 5000},
},
{
name: "stripe provider ignores payment type key even if present",
inst: testInstance(1, "stripe",
`{"stripe":{"singleMin":10,"singleMax":500},"alipay":{"singleMin":1,"singleMax":100}}`),
paymentType: "alipay",
want: ChannelLimits{SingleMin: 10, SingleMax: 500},
},
{
name: "non-stripe provider uses payment type as lookup key",
inst: testInstance(1, "easypay",
`{"alipay":{"singleMin":5},"wxpay":{"singleMin":10}}`),
paymentType: "wxpay",
want: ChannelLimits{SingleMin: 10},
},
{
name: "valid JSON with partial limits (only dailyLimit)",
inst: testInstance(1, "easypay",
`{"alipay":{"dailyLimit":800}}`),
paymentType: "alipay",
want: ChannelLimits{DailyLimit: 800},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := getInstanceChannelLimits(tt.inst, tt.paymentType)
if got != tt.want {
t.Fatalf("getInstanceChannelLimits() = %+v, want %+v", got, tt.want)
}
})
}
}
// ---------------------------------------------------------------------------
// startOfDay
// ---------------------------------------------------------------------------
func TestStartOfDay(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in time.Time
want time.Time
}{
{
name: "midday returns midnight of same day",
in: time.Date(2025, 6, 15, 14, 30, 45, 123456789, time.UTC),
want: time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "midnight returns same time",
in: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
want: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "last second of day returns midnight of same day",
in: time.Date(2025, 12, 31, 23, 59, 59, 999999999, time.UTC),
want: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC),
},
{
name: "preserves timezone location",
in: time.Date(2025, 3, 10, 15, 0, 0, 0, time.FixedZone("CST", 8*3600)),
want: time.Date(2025, 3, 10, 0, 0, 0, 0, time.FixedZone("CST", 8*3600)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := startOfDay(tt.in)
if !got.Equal(tt.want) {
t.Fatalf("startOfDay(%v) = %v, want %v", tt.in, got, tt.want)
}
// Also verify location is preserved.
if got.Location().String() != tt.want.Location().String() {
t.Fatalf("startOfDay() location = %v, want %v", got.Location(), tt.want.Location())
}
})
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// int64SliceEqual compares two int64 slices for equality.
// Both nil and empty slices are treated as equal.
func int64SliceEqual(a, b []int64) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
package provider
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/smartwalle/alipay/v3"
)
// Alipay product codes.
const (
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay = "QUICK_WAP_WAY"
)
// Alipay response constants.
const (
alipayFundChangeYes = "Y"
alipayErrTradeNotExist = "ACQ.TRADE_NOT_EXIST"
alipayRefundSuffix = "-refund"
)
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
config map[string]string // appId, privateKey, publicKey (or alipayPublicKey), notifyUrl, returnUrl
mu sync.Mutex
client *alipay.Client
}
// NewAlipay creates a new Alipay provider instance.
func NewAlipay(instanceID string, config map[string]string) (*Alipay, error) {
required := []string{"appId", "privateKey"}
for _, k := range required {
if config[k] == "" {
return nil, fmt.Errorf("alipay config missing required key: %s", k)
}
}
return &Alipay{
instanceID: instanceID,
config: config,
}, nil
}
func (a *Alipay) getClient() (*alipay.Client, error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.client != nil {
return a.client, nil
}
client, err := alipay.New(a.config["appId"], a.config["privateKey"], true)
if err != nil {
return nil, fmt.Errorf("alipay init client: %w", err)
}
pubKey := a.config["publicKey"]
if pubKey == "" {
pubKey = a.config["alipayPublicKey"]
}
if pubKey == "" {
return nil, fmt.Errorf("alipay config missing required key: publicKey (or alipayPublicKey)")
}
if err := client.LoadAliPayPublicKey(pubKey); err != nil {
return nil, fmt.Errorf("alipay load public key: %w", err)
}
a.client = client
return a.client, nil
}
func (a *Alipay) Name() string { return "Alipay" }
func (a *Alipay) ProviderKey() string { return payment.TypeAlipay }
func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipayDirect}
}
// CreatePayment creates an Alipay payment page URL.
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
}
notifyURL := a.config["notifyUrl"]
if req.NotifyURL != "" {
notifyURL = req.NotifyURL
}
returnURL := a.config["returnUrl"]
if req.ReturnURL != "" {
returnURL = req.ReturnURL
}
if req.IsMobile {
return a.createTrade(client, req, notifyURL, returnURL, true)
}
return a.createTrade(client, req, notifyURL, returnURL, false)
}
func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
if isMobile {
param := alipay.TradeWapPay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
param.Subject = req.Subject
param.ProductCode = alipayProductCodeWapPay
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
payURL, err := client.TradeWapPay(param)
if err != nil {
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
}, nil
}
param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
param.Subject = req.Subject
param.ProductCode = alipayProductCodePagePay
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
payURL, err := client.TradePagePay(param)
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
QRCode: payURL.String(),
}, nil
}
// QueryOrder queries the trade status via Alipay.
func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
}
result, err := client.TradeQuery(ctx, alipay.TradeQuery{OutTradeNo: tradeNo})
if err != nil {
if isTradeNotExist(err) {
return &payment.QueryOrderResponse{
TradeNo: tradeNo,
Status: payment.ProviderStatusPending,
}, nil
}
return nil, fmt.Errorf("alipay TradeQuery: %w", err)
}
status := payment.ProviderStatusPending
switch result.TradeStatus {
case alipay.TradeStatusSuccess, alipay.TradeStatusFinished:
status = payment.ProviderStatusPaid
case alipay.TradeStatusClosed:
status = payment.ProviderStatusFailed
}
amount, err := strconv.ParseFloat(result.TotalAmount, 64)
if err != nil {
return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err)
}
return &payment.QueryOrderResponse{
TradeNo: result.TradeNo,
Status: status,
Amount: amount,
PaidAt: result.SendPayDate,
}, nil
}
// VerifyNotification decodes and verifies an Alipay async notification.
func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
client, err := a.getClient()
if err != nil {
return nil, err
}
values, err := url.ParseQuery(rawBody)
if err != nil {
return nil, fmt.Errorf("alipay parse notification: %w", err)
}
notification, err := client.DecodeNotification(ctx, values)
if err != nil {
return nil, fmt.Errorf("alipay verify notification: %w", err)
}
status := payment.ProviderStatusFailed
if notification.TradeStatus == alipay.TradeStatusSuccess || notification.TradeStatus == alipay.TradeStatusFinished {
status = payment.ProviderStatusSuccess
}
amount, err := strconv.ParseFloat(notification.TotalAmount, 64)
if err != nil {
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
}
return &payment.PaymentNotification{
TradeNo: notification.TradeNo,
OrderID: notification.OutTradeNo,
Amount: amount,
Status: status,
RawData: rawBody,
}, nil
}
// Refund requests a refund through Alipay.
func (a *Alipay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
}
result, err := client.TradeRefund(ctx, alipay.TradeRefund{
OutTradeNo: req.OrderID,
RefundAmount: req.Amount,
RefundReason: req.Reason,
OutRequestNo: fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano()),
})
if err != nil {
return nil, fmt.Errorf("alipay TradeRefund: %w", err)
}
refundStatus := payment.ProviderStatusPending
if result.FundChange == alipayFundChangeYes {
refundStatus = payment.ProviderStatusSuccess
}
refundID := result.TradeNo
if refundID == "" {
refundID = req.OrderID + alipayRefundSuffix
}
return &payment.RefundResponse{
RefundID: refundID,
Status: refundStatus,
}, nil
}
// CancelPayment closes a pending trade on Alipay.
func (a *Alipay) CancelPayment(ctx context.Context, tradeNo string) error {
client, err := a.getClient()
if err != nil {
return err
}
_, err = client.TradeClose(ctx, alipay.TradeClose{OutTradeNo: tradeNo})
if err != nil {
if isTradeNotExist(err) {
return nil
}
return fmt.Errorf("alipay TradeClose: %w", err)
}
return nil
}
func isTradeNotExist(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), alipayErrTradeNotExist)
}
// Ensure interface compliance.
var (
_ payment.Provider = (*Alipay)(nil)
_ payment.CancelableProvider = (*Alipay)(nil)
)
//go:build unit
package provider
import (
"errors"
"strings"
"testing"
)
func TestIsTradeNotExist(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
want bool
}{
{
name: "nil error returns false",
err: nil,
want: false,
},
{
name: "error containing ACQ.TRADE_NOT_EXIST returns true",
err: errors.New("alipay: sub_code=ACQ.TRADE_NOT_EXIST, sub_msg=交易不存在"),
want: true,
},
{
name: "error not containing the code returns false",
err: errors.New("alipay: sub_code=ACQ.SYSTEM_ERROR, sub_msg=系统错误"),
want: false,
},
{
name: "error with only partial match returns false",
err: errors.New("ACQ.TRADE_NOT"),
want: false,
},
{
name: "error with exact constant value returns true",
err: errors.New(alipayErrTradeNotExist),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := isTradeNotExist(tt.err)
if got != tt.want {
t.Errorf("isTradeNotExist(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestNewAlipay(t *testing.T) {
t.Parallel()
validConfig := map[string]string{
"appId": "2021001234567890",
"privateKey": "MIIEvQIBADANBgkqhkiG9w0BAQEFAASC...",
}
// helper to clone and override config fields
withOverride := func(overrides map[string]string) map[string]string {
cfg := make(map[string]string, len(validConfig))
for k, v := range validConfig {
cfg[k] = v
}
for k, v := range overrides {
cfg[k] = v
}
return cfg
}
tests := []struct {
name string
config map[string]string
wantErr bool
errSubstr string
}{
{
name: "valid config succeeds",
config: validConfig,
wantErr: false,
},
{
name: "missing appId",
config: withOverride(map[string]string{"appId": ""}),
wantErr: true,
errSubstr: "appId",
},
{
name: "missing privateKey",
config: withOverride(map[string]string{"privateKey": ""}),
wantErr: true,
errSubstr: "privateKey",
},
{
name: "nil config map returns error for appId",
config: map[string]string{},
wantErr: true,
errSubstr: "appId",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := NewAlipay("test-instance", tt.config)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got == nil {
t.Fatal("expected non-nil Alipay instance")
}
if got.instanceID != "test-instance" {
t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
}
})
}
}
// Package provider contains concrete payment provider implementations.
package provider
import (
"context"
"crypto/hmac"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// EasyPay constants.
const (
easypayCodeSuccess = 1
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
deviceMobile = "mobile"
)
// EasyPay implements payment.Provider for the EasyPay aggregation platform.
type EasyPay struct {
instanceID string
config map[string]string
httpClient *http.Client
}
// NewEasyPay creates a new EasyPay provider.
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
if config[k] == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
return &EasyPay{
instanceID: instanceID,
config: config,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
mode := e.config["paymentMode"]
if mode == paymentModePopup {
return e.createRedirectPayment(req)
}
return e.createAPIPayment(ctx, req)
}
// createRedirectPayment builds a submit.php URL for browser redirect.
// No server-side API call — the user is redirected to EasyPay's hosted page.
// TradeNo is empty; it arrives via the notify callback after payment.
func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
notifyURL, returnURL := e.resolveURLs(req)
params := map[string]string{
"pid": e.config["pid"], "type": req.PaymentType,
"out_trade_no": req.OrderID, "notify_url": notifyURL,
"return_url": returnURL, "name": req.Subject,
"money": req.Amount,
}
if cid := e.resolveCID(req.PaymentType); cid != "" {
params["cid"] = cid
}
if req.IsMobile {
params["device"] = deviceMobile
}
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
q := url.Values{}
for k, v := range params {
q.Set(k, v)
}
base := strings.TrimRight(e.config["apiBase"], "/")
payURL := base + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
// createAPIPayment calls mapi.php to get payurl/qrcode (existing behavior).
func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
notifyURL, returnURL := e.resolveURLs(req)
params := map[string]string{
"pid": e.config["pid"], "type": req.PaymentType,
"out_trade_no": req.OrderID, "notify_url": notifyURL,
"return_url": returnURL, "name": req.Subject,
"money": req.Amount, "clientip": req.ClientIP,
}
if cid := e.resolveCID(req.PaymentType); cid != "" {
params["cid"] = cid
}
if req.IsMobile {
params["device"] = deviceMobile
}
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
var resp struct {
Code int `json:"code"`
Msg string `json:"msg"`
TradeNo string `json:"trade_no"`
PayURL string `json:"payurl"`
PayURL2 string `json:"payurl2"` // H5 mobile payment URL
QRCode string `json:"qrcode"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse: %w", err)
}
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay error: %s", resp.Msg)
}
payURL := resp.PayURL
if req.IsMobile && resp.PayURL2 != "" {
payURL = resp.PayURL2
}
return &payment.CreatePaymentResponse{TradeNo: resp.TradeNo, PayURL: payURL, QRCode: resp.QRCode}, nil
}
// resolveURLs returns (notifyURL, returnURL) preferring request values,
// falling back to instance config.
func (e *EasyPay) resolveURLs(req payment.CreatePaymentRequest) (string, string) {
notifyURL := req.NotifyURL
if notifyURL == "" {
notifyURL = e.config["notifyUrl"]
}
returnURL := req.ReturnURL
if returnURL == "" {
returnURL = e.config["returnUrl"]
}
return notifyURL, returnURL
}
func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
params := map[string]string{
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
var resp struct {
Code int `json:"code"`
Msg string `json:"msg"`
Status int `json:"status"`
Money string `json:"money"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse query: %w", err)
}
status := payment.ProviderStatusPending
if resp.Status == easypayStatusPaid {
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
values, err := url.ParseQuery(rawBody)
if err != nil {
return nil, fmt.Errorf("parse notify: %w", err)
}
// url.ParseQuery already decodes values — no additional decode needed.
params := make(map[string]string)
for k := range values {
params[k] = values.Get(k)
}
sign := params["sign"]
if sign == "" {
return nil, fmt.Errorf("missing sign")
}
if !easyPayVerifySign(params, e.config["pkey"], sign) {
return nil, fmt.Errorf("invalid signature")
}
status := payment.ProviderStatusFailed
if params["trade_status"] == tradeStatusSuccess {
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
Amount: amount, Status: status, RawData: rawBody,
}, nil
}
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
params := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"],
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
if err != nil {
return nil, fmt.Errorf("easypay refund: %w", err)
}
var resp struct {
Code int `json:"code"`
Msg string `json:"msg"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse refund: %w", err)
}
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
}
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
}
func (e *EasyPay) resolveCID(paymentType string) string {
if strings.HasPrefix(paymentType, "alipay") {
if v := e.config["cidAlipay"]; v != "" {
return v
}
return e.config["cid"]
}
if v := e.config["cidWxpay"]; v != "" {
return v
}
return e.config["cid"]
}
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := e.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
}
func easyPaySign(params map[string]string, pkey string) string {
keys := make([]string, 0, len(params))
for k, v := range params {
if k == "sign" || k == "sign_type" || v == "" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
var buf strings.Builder
for i, k := range keys {
if i > 0 {
_ = buf.WriteByte('&')
}
_, _ = buf.WriteString(k + "=" + params[k])
}
_, _ = buf.WriteString(pkey)
hash := md5.Sum([]byte(buf.String()))
return hex.EncodeToString(hash[:])
}
func easyPayVerifySign(params map[string]string, pkey string, sign string) bool {
return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign))
}
package provider
import (
"testing"
)
func TestEasyPaySignConsistentOutput(t *testing.T) {
t.Parallel()
params := map[string]string{
"pid": "1001",
"type": "alipay",
"out_trade_no": "ORDER123",
"name": "Test Product",
"money": "10.00",
}
pkey := "test_secret_key"
sign1 := easyPaySign(params, pkey)
sign2 := easyPaySign(params, pkey)
if sign1 != sign2 {
t.Fatalf("easyPaySign should be deterministic: %q != %q", sign1, sign2)
}
if len(sign1) != 32 {
t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign1))
}
}
func TestEasyPaySignExcludesSignAndSignType(t *testing.T) {
t.Parallel()
pkey := "my_key"
base := map[string]string{
"pid": "1001",
"type": "alipay",
}
withSign := map[string]string{
"pid": "1001",
"type": "alipay",
"sign": "should_be_ignored",
"sign_type": "MD5",
}
signBase := easyPaySign(base, pkey)
signWithExtra := easyPaySign(withSign, pkey)
if signBase != signWithExtra {
t.Fatalf("sign and sign_type should be excluded: base=%q, withExtra=%q", signBase, signWithExtra)
}
}
func TestEasyPaySignExcludesEmptyValues(t *testing.T) {
t.Parallel()
pkey := "key123"
base := map[string]string{
"pid": "1001",
"type": "alipay",
}
withEmpty := map[string]string{
"pid": "1001",
"type": "alipay",
"device": "",
"clientip": "",
}
signBase := easyPaySign(base, pkey)
signWithEmpty := easyPaySign(withEmpty, pkey)
if signBase != signWithEmpty {
t.Fatalf("empty values should be excluded: base=%q, withEmpty=%q", signBase, signWithEmpty)
}
}
func TestEasyPayVerifySignValid(t *testing.T) {
t.Parallel()
params := map[string]string{
"pid": "1001",
"type": "alipay",
"out_trade_no": "ORDER456",
"money": "25.00",
}
pkey := "secret"
sign := easyPaySign(params, pkey)
// Add sign to params (as would come in a real callback)
params["sign"] = sign
params["sign_type"] = "MD5"
if !easyPayVerifySign(params, pkey, sign) {
t.Fatal("easyPayVerifySign should return true for a valid signature")
}
}
func TestEasyPayVerifySignTampered(t *testing.T) {
t.Parallel()
params := map[string]string{
"pid": "1001",
"type": "alipay",
"out_trade_no": "ORDER789",
"money": "50.00",
}
pkey := "secret"
sign := easyPaySign(params, pkey)
// Tamper with the amount
params["money"] = "99.99"
if easyPayVerifySign(params, pkey, sign) {
t.Fatal("easyPayVerifySign should return false for tampered params")
}
}
func TestEasyPayVerifySignWrongKey(t *testing.T) {
t.Parallel()
params := map[string]string{
"pid": "1001",
"type": "wxpay",
}
sign := easyPaySign(params, "correct_key")
if easyPayVerifySign(params, "wrong_key", sign) {
t.Fatal("easyPayVerifySign should return false with wrong key")
}
}
func TestEasyPaySignEmptyParams(t *testing.T) {
t.Parallel()
sign := easyPaySign(map[string]string{}, "key123")
if sign == "" {
t.Fatal("easyPaySign with empty params should still produce a hash")
}
if len(sign) != 32 {
t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign))
}
}
func TestEasyPaySignSortOrder(t *testing.T) {
t.Parallel()
pkey := "test_key"
params1 := map[string]string{
"a": "1",
"b": "2",
"c": "3",
}
params2 := map[string]string{
"c": "3",
"a": "1",
"b": "2",
}
sign1 := easyPaySign(params1, pkey)
sign2 := easyPaySign(params2, pkey)
if sign1 != sign2 {
t.Fatalf("easyPaySign should be order-independent: %q != %q", sign1, sign2)
}
}
func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Parallel()
params := map[string]string{
"pid": "1001",
"type": "alipay",
}
pkey := "key"
if easyPayVerifySign(params, pkey, "00000000000000000000000000000000") {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
package provider
import (
"fmt"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// CreateProvider creates a Provider from a provider key, instance ID and decrypted config.
func CreateProvider(providerKey string, instanceID string, config map[string]string) (payment.Provider, error) {
switch providerKey {
case payment.TypeEasyPay:
return NewEasyPay(instanceID, config)
case payment.TypeAlipay:
return NewAlipay(instanceID, config)
case payment.TypeWxpay:
return NewWxpay(instanceID, config)
case payment.TypeStripe:
return NewStripe(instanceID, config)
default:
return nil, fmt.Errorf("unknown provider key: %s", providerKey)
}
}
package provider
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/payment"
stripe "github.com/stripe/stripe-go/v85"
"github.com/stripe/stripe-go/v85/webhook"
)
// Stripe constants.
const (
stripeCurrency = "cny"
stripeEventPaymentSuccess = "payment_intent.succeeded"
stripeEventPaymentFailed = "payment_intent.payment_failed"
)
// Stripe implements the payment.CancelableProvider interface for Stripe payments.
type Stripe struct {
instanceID string
config map[string]string
mu sync.Mutex
initialized bool
sc *stripe.Client
}
// NewStripe creates a new Stripe provider instance.
func NewStripe(instanceID string, config map[string]string) (*Stripe, error) {
if config["secretKey"] == "" {
return nil, fmt.Errorf("stripe config missing required key: secretKey")
}
return &Stripe{
instanceID: instanceID,
config: config,
}, nil
}
func (s *Stripe) ensureInit() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.initialized {
s.sc = stripe.NewClient(s.config["secretKey"])
s.initialized = true
}
}
// GetPublishableKey returns the publishable key for frontend use.
func (s *Stripe) GetPublishableKey() string {
return s.config["publishableKey"]
}
func (s *Stripe) Name() string { return "Stripe" }
func (s *Stripe) ProviderKey() string { return payment.TypeStripe }
func (s *Stripe) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeStripe}
}
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
var stripePaymentMethodTypes = map[string][]string{
payment.TypeCard: {"card"},
payment.TypeAlipay: {"alipay"},
payment.TypeWxpay: {"wechat_pay"},
payment.TypeLink: {"link"},
}
// CreatePayment creates a Stripe PaymentIntent.
func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
s.ensureInit()
amountInCents, err := payment.YuanToFen(req.Amount)
if err != nil {
return nil, fmt.Errorf("stripe create payment: %w", err)
}
// Collect all Stripe payment_method_types from the instance's configured sub-methods
methods := resolveStripeMethodTypes(req.InstanceSubMethods)
pmTypes := make([]*string, len(methods))
for i, m := range methods {
pmTypes[i] = stripe.String(m)
}
params := &stripe.PaymentIntentCreateParams{
Amount: stripe.Int64(amountInCents),
Currency: stripe.String(stripeCurrency),
PaymentMethodTypes: pmTypes,
Description: stripe.String(req.Subject),
Metadata: map[string]string{"orderId": req.OrderID},
}
// WeChat Pay requires payment_method_options with client type
if hasStripeMethod(methods, "wechat_pay") {
params.PaymentMethodOptions = &stripe.PaymentIntentCreatePaymentMethodOptionsParams{
WeChatPay: &stripe.PaymentIntentCreatePaymentMethodOptionsWeChatPayParams{
Client: stripe.String("web"),
},
}
}
params.SetIdempotencyKey(fmt.Sprintf("pi-%s", req.OrderID))
params.Context = ctx
pi, err := s.sc.V1PaymentIntents.Create(ctx, params)
if err != nil {
return nil, fmt.Errorf("stripe create payment: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: pi.ID,
ClientSecret: pi.ClientSecret,
}, nil
}
// QueryOrder retrieves a PaymentIntent by ID.
func (s *Stripe) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
s.ensureInit()
pi, err := s.sc.V1PaymentIntents.Retrieve(ctx, tradeNo, nil)
if err != nil {
return nil, fmt.Errorf("stripe query order: %w", err)
}
status := payment.ProviderStatusPending
switch pi.Status {
case stripe.PaymentIntentStatusSucceeded:
status = payment.ProviderStatusPaid
case stripe.PaymentIntentStatusCanceled:
status = payment.ProviderStatusFailed
}
return &payment.QueryOrderResponse{
TradeNo: pi.ID,
Status: status,
Amount: payment.FenToYuan(pi.Amount),
}, nil
}
// VerifyNotification verifies a Stripe webhook event.
func (s *Stripe) VerifyNotification(_ context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
s.ensureInit()
webhookSecret := s.config["webhookSecret"]
if webhookSecret == "" {
return nil, fmt.Errorf("stripe webhookSecret not configured")
}
sig := headers["stripe-signature"]
if sig == "" {
return nil, fmt.Errorf("stripe notification missing stripe-signature header")
}
event, err := webhook.ConstructEvent([]byte(rawBody), sig, webhookSecret)
if err != nil {
return nil, fmt.Errorf("stripe verify notification: %w", err)
}
switch event.Type {
case stripeEventPaymentSuccess:
return parseStripePaymentIntent(&event, payment.ProviderStatusSuccess, rawBody)
case stripeEventPaymentFailed:
return parseStripePaymentIntent(&event, payment.ProviderStatusFailed, rawBody)
}
return nil, nil
}
func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string) (*payment.PaymentNotification, error) {
var pi stripe.PaymentIntent
if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
}
return &payment.PaymentNotification{
TradeNo: pi.ID,
OrderID: pi.Metadata["orderId"],
Amount: payment.FenToYuan(pi.Amount),
Status: status,
RawData: rawBody,
}, nil
}
// Refund creates a Stripe refund.
func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
s.ensureInit()
amountInCents, err := payment.YuanToFen(req.Amount)
if err != nil {
return nil, fmt.Errorf("stripe refund: %w", err)
}
params := &stripe.RefundCreateParams{
PaymentIntent: stripe.String(req.TradeNo),
Amount: stripe.Int64(amountInCents),
Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
}
params.Context = ctx
r, err := s.sc.V1Refunds.Create(ctx, params)
if err != nil {
return nil, fmt.Errorf("stripe refund: %w", err)
}
refundStatus := payment.ProviderStatusPending
if r.Status == stripe.RefundStatusSucceeded {
refundStatus = payment.ProviderStatusSuccess
}
return &payment.RefundResponse{
RefundID: r.ID,
Status: refundStatus,
}, nil
}
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
func resolveStripeMethodTypes(instanceSubMethods string) []string {
if instanceSubMethods == "" {
return []string{"card"}
}
var methods []string
for _, t := range strings.Split(instanceSubMethods, ",") {
t = strings.TrimSpace(t)
if mapped, ok := stripePaymentMethodTypes[t]; ok {
methods = append(methods, mapped...)
}
}
if len(methods) == 0 {
return []string{"card"}
}
return methods
}
// hasStripeMethod checks if the given Stripe method list contains the target method.
func hasStripeMethod(methods []string, target string) bool {
for _, m := range methods {
if m == target {
return true
}
}
return false
}
// CancelPayment cancels a pending PaymentIntent.
func (s *Stripe) CancelPayment(ctx context.Context, tradeNo string) error {
s.ensureInit()
_, err := s.sc.V1PaymentIntents.Cancel(ctx, tradeNo, nil)
if err != nil {
return fmt.Errorf("stripe cancel payment: %w", err)
}
return nil
}
// Ensure interface compliance.
var (
_ payment.Provider = (*Stripe)(nil)
_ payment.CancelableProvider = (*Stripe)(nil)
)
package provider
import (
"bytes"
"context"
"crypto/rsa"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
)
// WeChat Pay constants.
const (
wxpayCurrency = "CNY"
wxpayH5Type = "Wap"
)
// WeChat Pay trade states.
const (
wxpayTradeStateSuccess = "SUCCESS"
wxpayTradeStateRefund = "REFUND"
wxpayTradeStateClosed = "CLOSED"
wxpayTradeStatePayError = "PAYERROR"
)
// WeChat Pay notification event types.
const (
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const (
wxpayErrNoAuth = "NO_AUTH"
)
type Wxpay struct {
instanceID string
config map[string]string
mu sync.Mutex
coreClient *core.Client
notifyHandler *notify.Handler
}
func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"}
for _, k := range required {
if config[k] == "" {
return nil, fmt.Errorf("wxpay config missing required key: %s", k)
}
}
if len(config["apiV3Key"]) != 32 {
return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"]))
}
return &Wxpay{instanceID: instanceID, config: config}, nil
}
func (w *Wxpay) Name() string { return "Wxpay" }
func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay }
func (w *Wxpay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeWxpayDirect}
}
func formatPEM(key, keyType string) string {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "-----BEGIN") {
return key
}
return fmt.Sprintf("-----BEGIN %s-----\n%s\n-----END %s-----", keyType, key, keyType)
}
func (w *Wxpay) ensureClient() (*core.Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.coreClient != nil {
return w.coreClient, nil
}
privateKey, publicKey, err := w.loadKeyPair()
if err != nil {
return nil, err
}
certSerial := w.config["certSerial"]
verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
client, err := core.NewClient(context.Background(),
option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey),
option.WithVerifier(verifier))
if err != nil {
return nil, fmt.Errorf("wxpay init client: %w", err)
}
handler, err := notify.NewRSANotifyHandler(w.config["apiV3Key"], verifier)
if err != nil {
return nil, fmt.Errorf("wxpay init notify handler: %w", err)
}
w.notifyHandler = handler
w.coreClient = client
return w.coreClient, nil
}
func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) {
privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
if err != nil {
return nil, nil, fmt.Errorf("wxpay load private key: %w", err)
}
publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
if err != nil {
return nil, nil, fmt.Errorf("wxpay load public key: %w", err)
}
return privateKey, publicKey, nil
}
func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := w.ensureClient()
if err != nil {
return nil, err
}
// Request-first, config-fallback (consistent with EasyPay/Alipay)
notifyURL := req.NotifyURL
if notifyURL == "" {
notifyURL = w.config["notifyUrl"]
}
if notifyURL == "" {
return nil, fmt.Errorf("wxpay notifyUrl is required")
}
totalFen, err := payment.YuanToFen(req.Amount)
if err != nil {
return nil, fmt.Errorf("wxpay create payment: %w", err)
}
if req.IsMobile && req.ClientIP != "" {
resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true)
if err == nil {
return resp, nil
}
if !strings.Contains(err.Error(), wxpayErrNoAuth) {
return nil, err
}
slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID)
}
return w.createOrder(ctx, client, req, notifyURL, totalFen, false)
}
func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) {
if useH5 {
return w.prepayH5(ctx, c, req, notifyURL, totalFen)
}
return w.prepayNative(ctx, c, req, notifyURL, totalFen)
}
func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := native.NativeApiService{Client: c}
cur := wxpayCurrency
resp, _, err := svc.Prepay(ctx, native.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &native.Amount{Total: core.Int64(totalFen), Currency: &cur},
})
if err != nil {
return nil, fmt.Errorf("wxpay native prepay: %w", err)
}
codeURL := ""
if resp.CodeUrl != nil {
codeURL = *resp.CodeUrl
}
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, QRCode: codeURL}, nil
}
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := h5.H5ApiService{Client: c}
cur := wxpayCurrency
tp := wxpayH5Type
resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
})
if err != nil {
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
}
h5URL := ""
if resp.H5Url != nil {
h5URL = *resp.H5Url
}
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
}
func wxSV(s *string) string {
if s == nil {
return ""
}
return *s
}
func mapWxState(s string) string {
switch s {
case wxpayTradeStateSuccess:
return payment.ProviderStatusPaid
case wxpayTradeStateRefund:
return payment.ProviderStatusRefunded
case wxpayTradeStateClosed, wxpayTradeStatePayError:
return payment.ProviderStatusFailed
default:
return payment.ProviderStatusPending
}
}
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
return nil, err
}
svc := native.NativeApiService{Client: c}
tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
})
if err != nil {
return nil, fmt.Errorf("wxpay query order: %w", err)
}
var amt float64
if tx.Amount != nil && tx.Amount.Total != nil {
amt = payment.FenToYuan(*tx.Amount.Total)
}
id := tradeNo
if tx.TransactionId != nil {
id = *tx.TransactionId
}
pa := ""
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
if _, err := w.ensureClient(); err != nil {
return nil, err
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost, "/", io.NopCloser(bytes.NewBufferString(rawBody)))
if err != nil {
return nil, fmt.Errorf("wxpay construct request: %w", err)
}
for k, v := range headers {
r.Header.Set(k, v)
}
var tx payments.Transaction
nr, err := w.notifyHandler.ParseNotifyRequest(ctx, r, &tx)
if err != nil {
return nil, fmt.Errorf("wxpay verify notification: %w", err)
}
if nr.EventType != wxpayEventTransactionSuccess {
return nil, nil
}
var amt float64
if tx.Amount != nil && tx.Amount.Total != nil {
amt = payment.FenToYuan(*tx.Amount.Total)
}
st := payment.ProviderStatusFailed
if wxSV(tx.TradeState) == wxpayTradeStateSuccess {
st = payment.ProviderStatusSuccess
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
Amount: amt, Status: st, RawData: rawBody,
}, nil
}
func (w *Wxpay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
c, err := w.ensureClient()
if err != nil {
return nil, err
}
rf, err := payment.YuanToFen(req.Amount)
if err != nil {
return nil, fmt.Errorf("wxpay refund amount: %w", err)
}
tf, err := w.queryOrderTotalFen(ctx, c, req.OrderID)
if err != nil {
return nil, err
}
rs := refunddomestic.RefundsApiService{Client: c}
cur := wxpayCurrency
res, _, err := rs.Create(ctx, refunddomestic.CreateRequest{
OutTradeNo: core.String(req.OrderID),
OutRefundNo: core.String(fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano())),
Reason: core.String(req.Reason),
Amount: &refunddomestic.AmountReq{Refund: core.Int64(rf), Total: core.Int64(tf), Currency: &cur},
})
if err != nil {
return nil, fmt.Errorf("wxpay refund: %w", err)
}
rid := wxSV(res.RefundId)
if rid == "" {
rid = fmt.Sprintf("%s-refund", req.OrderID)
}
st := payment.ProviderStatusPending
if res.Status != nil && *res.Status == refunddomestic.STATUS_SUCCESS {
st = payment.ProviderStatusSuccess
}
return &payment.RefundResponse{RefundID: rid, Status: st}, nil
}
func (w *Wxpay) queryOrderTotalFen(ctx context.Context, c *core.Client, orderID string) (int64, error) {
svc := native.NativeApiService{Client: c}
tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
OutTradeNo: core.String(orderID), Mchid: core.String(w.config["mchId"]),
})
if err != nil {
return 0, fmt.Errorf("wxpay refund query order: %w", err)
}
var tf int64
if tx.Amount != nil && tx.Amount.Total != nil {
tf = *tx.Amount.Total
}
return tf, nil
}
func (w *Wxpay) CancelPayment(ctx context.Context, tradeNo string) error {
c, err := w.ensureClient()
if err != nil {
return err
}
svc := native.NativeApiService{Client: c}
_, err = svc.CloseOrder(ctx, native.CloseOrderRequest{
OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
})
if err != nil {
return fmt.Errorf("wxpay cancel payment: %w", err)
}
return nil
}
var (
_ payment.Provider = (*Wxpay)(nil)
_ payment.CancelableProvider = (*Wxpay)(nil)
)
//go:build unit
package provider
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestMapWxState(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{
name: "SUCCESS maps to paid",
input: wxpayTradeStateSuccess,
want: payment.ProviderStatusPaid,
},
{
name: "REFUND maps to refunded",
input: wxpayTradeStateRefund,
want: payment.ProviderStatusRefunded,
},
{
name: "CLOSED maps to failed",
input: wxpayTradeStateClosed,
want: payment.ProviderStatusFailed,
},
{
name: "PAYERROR maps to failed",
input: wxpayTradeStatePayError,
want: payment.ProviderStatusFailed,
},
{
name: "unknown state maps to pending",
input: "NOTPAY",
want: payment.ProviderStatusPending,
},
{
name: "empty string maps to pending",
input: "",
want: payment.ProviderStatusPending,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := mapWxState(tt.input)
if got != tt.want {
t.Errorf("mapWxState(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestWxSV(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input *string
want string
}{
{
name: "nil pointer returns empty string",
input: nil,
want: "",
},
{
name: "non-nil pointer returns value",
input: strPtr("hello"),
want: "hello",
},
{
name: "pointer to empty string returns empty string",
input: strPtr(""),
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := wxSV(tt.input)
if got != tt.want {
t.Errorf("wxSV() = %q, want %q", got, tt.want)
}
})
}
}
func strPtr(s string) *string {
return &s
}
func TestFormatPEM(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
keyType string
want string
}{
{
name: "raw key gets wrapped with headers",
key: "MIIBIjANBgkqhki...",
keyType: "PUBLIC KEY",
want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
},
{
name: "already formatted key is returned as-is",
key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
keyType: "PRIVATE KEY",
want: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
},
{
name: "key with leading/trailing whitespace is trimmed before check",
key: " \n MIIBIjANBgkqhki... \n ",
keyType: "PUBLIC KEY",
want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
},
{
name: "already formatted key with whitespace is trimmed and returned",
key: " -----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY----- ",
keyType: "RSA PRIVATE KEY",
want: "-----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY-----",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := formatPEM(tt.key, tt.keyType)
if got != tt.want {
t.Errorf("formatPEM(%q, %q) =\n%s\nwant:\n%s", tt.key, tt.keyType, got, tt.want)
}
})
}
}
func TestNewWxpay(t *testing.T) {
t.Parallel()
validConfig := map[string]string{
"appId": "wx1234567890",
"mchId": "1234567890",
"privateKey": "fake-private-key",
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
"publicKey": "fake-public-key",
"publicKeyId": "key-id-001",
"certSerial": "SERIAL001",
}
// helper to clone and override config fields
withOverride := func(overrides map[string]string) map[string]string {
cfg := make(map[string]string, len(validConfig))
for k, v := range validConfig {
cfg[k] = v
}
for k, v := range overrides {
cfg[k] = v
}
return cfg
}
tests := []struct {
name string
config map[string]string
wantErr bool
errSubstr string
}{
{
name: "valid config succeeds",
config: validConfig,
wantErr: false,
},
{
name: "missing appId",
config: withOverride(map[string]string{"appId": ""}),
wantErr: true,
errSubstr: "appId",
},
{
name: "missing mchId",
config: withOverride(map[string]string{"mchId": ""}),
wantErr: true,
errSubstr: "mchId",
},
{
name: "missing privateKey",
config: withOverride(map[string]string{"privateKey": ""}),
wantErr: true,
errSubstr: "privateKey",
},
{
name: "missing apiV3Key",
config: withOverride(map[string]string{"apiV3Key": ""}),
wantErr: true,
errSubstr: "apiV3Key",
},
{
name: "missing publicKey",
config: withOverride(map[string]string{"publicKey": ""}),
wantErr: true,
errSubstr: "publicKey",
},
{
name: "missing publicKeyId",
config: withOverride(map[string]string{"publicKeyId": ""}),
wantErr: true,
errSubstr: "publicKeyId",
},
{
name: "apiV3Key too short",
config: withOverride(map[string]string{"apiV3Key": "short"}),
wantErr: true,
errSubstr: "exactly 32 bytes",
},
{
name: "apiV3Key too long",
config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
wantErr: true,
errSubstr: "exactly 32 bytes",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := NewWxpay("test-instance", tt.config)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got == nil {
t.Fatal("expected non-nil Wxpay instance")
}
if got.instanceID != "test-instance" {
t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
}
})
}
}
package payment
import (
"sync"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Registry is a thread-safe registry mapping PaymentType to Provider.
type Registry struct {
mu sync.RWMutex
providers map[PaymentType]Provider
}
// ErrProviderNotFound is returned when a requested payment provider is not registered.
var ErrProviderNotFound = infraerrors.NotFound("PROVIDER_NOT_FOUND", "payment provider not registered")
// NewRegistry creates a new empty provider registry.
func NewRegistry() *Registry {
return &Registry{
providers: make(map[PaymentType]Provider),
}
}
// Register adds a provider for each of its supported payment types.
// If a type was previously registered, it is overwritten.
func (r *Registry) Register(p Provider) {
r.mu.Lock()
defer r.mu.Unlock()
for _, t := range p.SupportedTypes() {
r.providers[t] = p
}
}
// GetProvider returns the provider registered for the given payment type.
func (r *Registry) GetProvider(t PaymentType) (Provider, error) {
r.mu.RLock()
defer r.mu.RUnlock()
p, ok := r.providers[t]
if !ok {
return nil, ErrProviderNotFound
}
return p, nil
}
// GetProviderByKey returns the first provider whose ProviderKey matches the given key.
func (r *Registry) GetProviderByKey(key string) (Provider, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, p := range r.providers {
if p.ProviderKey() == key {
return p, nil
}
}
return nil, ErrProviderNotFound
}
// GetProviderKey returns the provider key for the given payment type, or empty string if not found.
func (r *Registry) GetProviderKey(t PaymentType) string {
r.mu.RLock()
defer r.mu.RUnlock()
p, ok := r.providers[t]
if !ok {
return ""
}
return p.ProviderKey()
}
// SupportedTypes returns all currently registered payment types.
func (r *Registry) SupportedTypes() []PaymentType {
r.mu.RLock()
defer r.mu.RUnlock()
types := make([]PaymentType, 0, len(r.providers))
for t := range r.providers {
types = append(types, t)
}
return types
}
// Clear removes all registered providers.
func (r *Registry) Clear() {
r.mu.Lock()
defer r.mu.Unlock()
r.providers = make(map[PaymentType]Provider)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment