Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
97f14b7a
Unverified
Commit
97f14b7a
authored
Apr 11, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 11, 2026
Browse files
Merge pull request #1572 from touwaeriol/feat/payment-system-v2
feat(payment): add complete payment system with multi-provider support
parents
1ef3782d
6793503e
Changes
174
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/payment_webhook_handler_test.go
0 → 100644
View file @
97f14b7a
//go:build unit
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestWriteSuccessResponse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
providerKey
string
wantCode
int
wantContentType
string
wantBody
string
checkJSON
bool
wantJSONCode
string
wantJSONMessage
string
}{
{
name
:
"wxpay returns JSON with code SUCCESS"
,
providerKey
:
"wxpay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"application/json"
,
checkJSON
:
true
,
wantJSONCode
:
"SUCCESS"
,
wantJSONMessage
:
"成功"
,
},
{
name
:
"stripe returns empty 200"
,
providerKey
:
"stripe"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
""
,
},
{
name
:
"easypay returns plain text success"
,
providerKey
:
"easypay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
{
name
:
"alipay returns plain text success"
,
providerKey
:
"alipay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
{
name
:
"unknown provider returns plain text success"
,
providerKey
:
"unknown_provider"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
writeSuccessResponse
(
c
,
tt
.
providerKey
)
assert
.
Equal
(
t
,
tt
.
wantCode
,
w
.
Code
)
assert
.
Contains
(
t
,
w
.
Header
()
.
Get
(
"Content-Type"
),
tt
.
wantContentType
)
if
tt
.
checkJSON
{
var
resp
wxpaySuccessResponse
err
:=
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
resp
)
require
.
NoError
(
t
,
err
,
"response body should be valid JSON"
)
assert
.
Equal
(
t
,
tt
.
wantJSONCode
,
resp
.
Code
)
assert
.
Equal
(
t
,
tt
.
wantJSONMessage
,
resp
.
Message
)
}
else
{
assert
.
Equal
(
t
,
tt
.
wantBody
,
w
.
Body
.
String
())
}
})
}
}
func
TestWebhookConstants
(
t
*
testing
.
T
)
{
t
.
Run
(
"maxWebhookBodySize is 1MB"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
int64
(
1
<<
20
),
int64
(
maxWebhookBodySize
))
})
t
.
Run
(
"webhookLogTruncateLen is 200"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
200
,
webhookLogTruncateLen
)
})
}
backend/internal/handler/setting_handler.go
View file @
97f14b7a
...
...
@@ -59,6 +59,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
OIDCOAuthEnabled
:
settings
.
OIDCOAuthEnabled
,
OIDCOAuthProviderName
:
settings
.
OIDCOAuthProviderName
,
BackendModeEnabled
:
settings
.
BackendModeEnabled
,
PaymentEnabled
:
settings
.
PaymentEnabled
,
Version
:
h
.
version
,
})
}
backend/internal/handler/wire.go
View file @
97f14b7a
...
...
@@ -34,6 +34,7 @@ func ProvideAdminHandlers(
apiKeyHandler
*
admin
.
AdminAPIKeyHandler
,
scheduledTestHandler
*
admin
.
ScheduledTestHandler
,
channelHandler
*
admin
.
ChannelHandler
,
paymentHandler
*
admin
.
PaymentHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
Dashboard
:
dashboardHandler
,
...
...
@@ -61,6 +62,7 @@ func ProvideAdminHandlers(
APIKey
:
apiKeyHandler
,
ScheduledTest
:
scheduledTestHandler
,
Channel
:
channelHandler
,
Payment
:
paymentHandler
,
}
}
...
...
@@ -88,6 +90,8 @@ func ProvideHandlers(
openaiGatewayHandler
*
OpenAIGatewayHandler
,
settingHandler
*
SettingHandler
,
totpHandler
*
TotpHandler
,
paymentHandler
*
PaymentHandler
,
paymentWebhookHandler
*
PaymentWebhookHandler
,
_
*
service
.
IdempotencyCoordinator
,
_
*
service
.
IdempotencyCleanupService
,
)
*
Handlers
{
...
...
@@ -104,6 +108,8 @@ func ProvideHandlers(
OpenAIGateway
:
openaiGatewayHandler
,
Setting
:
settingHandler
,
Totp
:
totpHandler
,
Payment
:
paymentHandler
,
PaymentWebhook
:
paymentWebhookHandler
,
}
}
...
...
@@ -121,6 +127,8 @@ var ProviderSet = wire.NewSet(
NewOpenAIGatewayHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
NewPaymentHandler
,
NewPaymentWebhookHandler
,
// Admin handlers
admin
.
NewDashboardHandler
,
...
...
@@ -148,6 +156,7 @@ var ProviderSet = wire.NewSet(
admin
.
NewAdminAPIKeyHandler
,
admin
.
NewScheduledTestHandler
,
admin
.
NewChannelHandler
,
admin
.
NewPaymentHandler
,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers
,
...
...
backend/internal/payment/amount.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"fmt"
"github.com/shopspring/decimal"
)
const
centsPerYuan
=
100
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
// Uses shopspring/decimal for precision.
func
YuanToFen
(
yuanStr
string
)
(
int64
,
error
)
{
d
,
err
:=
decimal
.
NewFromString
(
yuanStr
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"invalid amount: %s"
,
yuanStr
)
}
return
d
.
Mul
(
decimal
.
NewFromInt
(
centsPerYuan
))
.
IntPart
(),
nil
}
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
func
FenToYuan
(
fen
int64
)
float64
{
return
decimal
.
NewFromInt
(
fen
)
.
Div
(
decimal
.
NewFromInt
(
centsPerYuan
))
.
InexactFloat64
()
}
backend/internal/payment/amount_test.go
0 → 100644
View file @
97f14b7a
//go:build unit
package
payment
import
(
"math"
"testing"
)
func
TestYuanToFen
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
want
int64
wantErr
bool
}{
// Normal values
{
name
:
"one yuan"
,
input
:
"1.00"
,
want
:
100
},
{
name
:
"ten yuan fifty fen"
,
input
:
"10.50"
,
want
:
1050
},
{
name
:
"one fen"
,
input
:
"0.01"
,
want
:
1
},
{
name
:
"large amount"
,
input
:
"99999.99"
,
want
:
9999999
},
// Edge: zero
{
name
:
"zero no decimal"
,
input
:
"0"
,
want
:
0
},
{
name
:
"zero with decimal"
,
input
:
"0.00"
,
want
:
0
},
// IEEE 754 precision edge case: 1.15 * 100 = 114.99999... in float64
{
name
:
"ieee754 precision 1.15"
,
input
:
"1.15"
,
want
:
115
},
// More precision edge cases
{
name
:
"ieee754 precision 0.1"
,
input
:
"0.1"
,
want
:
10
},
{
name
:
"ieee754 precision 0.2"
,
input
:
"0.2"
,
want
:
20
},
{
name
:
"ieee754 precision 33.33"
,
input
:
"33.33"
,
want
:
3333
},
// Large value
{
name
:
"hundred thousand"
,
input
:
"100000.00"
,
want
:
10000000
},
// Integer without decimal
{
name
:
"integer 5"
,
input
:
"5"
,
want
:
500
},
{
name
:
"integer 100"
,
input
:
"100"
,
want
:
10000
},
// Single decimal place
{
name
:
"single decimal 1.5"
,
input
:
"1.5"
,
want
:
150
},
// Negative values
{
name
:
"negative one yuan"
,
input
:
"-1.00"
,
want
:
-
100
},
{
name
:
"negative with fen"
,
input
:
"-10.50"
,
want
:
-
1050
},
// Invalid inputs
{
name
:
"empty string"
,
input
:
""
,
wantErr
:
true
},
{
name
:
"alphabetic"
,
input
:
"abc"
,
wantErr
:
true
},
{
name
:
"double dot"
,
input
:
"1.2.3"
,
wantErr
:
true
},
{
name
:
"spaces"
,
input
:
" "
,
wantErr
:
true
},
{
name
:
"special chars"
,
input
:
"$10.00"
,
wantErr
:
true
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
,
err
:=
YuanToFen
(
tt
.
input
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Errorf
(
"YuanToFen(%q) expected error, got %d"
,
tt
.
input
,
got
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"YuanToFen(%q) unexpected error: %v"
,
tt
.
input
,
err
)
}
if
got
!=
tt
.
want
{
t
.
Errorf
(
"YuanToFen(%q) = %d, want %d"
,
tt
.
input
,
got
,
tt
.
want
)
}
})
}
}
func
TestFenToYuan
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
fen
int64
want
float64
}{
{
name
:
"one yuan"
,
fen
:
100
,
want
:
1.0
},
{
name
:
"ten yuan fifty fen"
,
fen
:
1050
,
want
:
10.5
},
{
name
:
"one fen"
,
fen
:
1
,
want
:
0.01
},
{
name
:
"zero"
,
fen
:
0
,
want
:
0.0
},
{
name
:
"large amount"
,
fen
:
9999999
,
want
:
99999.99
},
{
name
:
"negative"
,
fen
:
-
100
,
want
:
-
1.0
},
{
name
:
"negative with fen"
,
fen
:
-
1050
,
want
:
-
10.5
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
FenToYuan
(
tt
.
fen
)
if
math
.
Abs
(
got
-
tt
.
want
)
>
1e-9
{
t
.
Errorf
(
"FenToYuan(%d) = %f, want %f"
,
tt
.
fen
,
got
,
tt
.
want
)
}
})
}
}
func
TestYuanToFenRoundTrip
(
t
*
testing
.
T
)
{
// Verify that converting yuan->fen->yuan preserves the value.
cases
:=
[]
struct
{
yuan
string
fen
int64
}{
{
"0.01"
,
1
},
{
"1.00"
,
100
},
{
"10.50"
,
1050
},
{
"99999.99"
,
9999999
},
}
for
_
,
tc
:=
range
cases
{
fen
,
err
:=
YuanToFen
(
tc
.
yuan
)
if
err
!=
nil
{
t
.
Fatalf
(
"YuanToFen(%q) unexpected error: %v"
,
tc
.
yuan
,
err
)
}
if
fen
!=
tc
.
fen
{
t
.
Errorf
(
"YuanToFen(%q) = %d, want %d"
,
tc
.
yuan
,
fen
,
tc
.
fen
)
}
yuan
:=
FenToYuan
(
fen
)
// Parse expected yuan back for comparison
expectedYuan
:=
FenToYuan
(
tc
.
fen
)
if
math
.
Abs
(
yuan
-
expectedYuan
)
>
1e-9
{
t
.
Errorf
(
"round-trip: FenToYuan(%d) = %f, want %f"
,
fen
,
yuan
,
expectedYuan
)
}
}
}
backend/internal/payment/crypto.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"strings"
)
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
func
Encrypt
(
plaintext
string
,
key
[]
byte
)
(
string
,
error
)
{
if
len
(
key
)
!=
32
{
return
""
,
fmt
.
Errorf
(
"encryption key must be 32 bytes, got %d"
,
len
(
key
))
}
block
,
err
:=
aes
.
NewCipher
(
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create AES cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create GCM: %w"
,
err
)
}
nonce
:=
make
([]
byte
,
gcm
.
NonceSize
())
// 12 bytes for GCM
if
_
,
err
:=
io
.
ReadFull
(
rand
.
Reader
,
nonce
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate nonce: %w"
,
err
)
}
// Seal appends the ciphertext + auth tag
sealed
:=
gcm
.
Seal
(
nil
,
nonce
,
[]
byte
(
plaintext
),
nil
)
// Split sealed into ciphertext and auth tag (last 16 bytes)
tagSize
:=
gcm
.
Overhead
()
ciphertext
:=
sealed
[
:
len
(
sealed
)
-
tagSize
]
authTag
:=
sealed
[
len
(
sealed
)
-
tagSize
:
]
// Format: iv:authTag:ciphertext (all base64)
return
fmt
.
Sprintf
(
"%s:%s:%s"
,
base64
.
StdEncoding
.
EncodeToString
(
nonce
),
base64
.
StdEncoding
.
EncodeToString
(
authTag
),
base64
.
StdEncoding
.
EncodeToString
(
ciphertext
),
),
nil
}
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
func
Decrypt
(
ciphertext
string
,
key
[]
byte
)
(
string
,
error
)
{
if
len
(
key
)
!=
32
{
return
""
,
fmt
.
Errorf
(
"encryption key must be 32 bytes, got %d"
,
len
(
key
))
}
parts
:=
strings
.
SplitN
(
ciphertext
,
":"
,
3
)
if
len
(
parts
)
!=
3
{
return
""
,
fmt
.
Errorf
(
"invalid ciphertext format: expected iv:authTag:ciphertext"
)
}
nonce
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
0
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode IV: %w"
,
err
)
}
authTag
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
1
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode auth tag: %w"
,
err
)
}
encrypted
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
2
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode ciphertext: %w"
,
err
)
}
block
,
err
:=
aes
.
NewCipher
(
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create AES cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create GCM: %w"
,
err
)
}
// Reconstruct the sealed data: ciphertext + authTag
sealed
:=
append
(
encrypted
,
authTag
...
)
plaintext
,
err
:=
gcm
.
Open
(
nil
,
nonce
,
sealed
,
nil
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decrypt: %w"
,
err
)
}
return
string
(
plaintext
),
nil
}
backend/internal/payment/crypto_test.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"crypto/rand"
"strings"
"testing"
)
func
makeKey
(
t
*
testing
.
T
)
[]
byte
{
t
.
Helper
()
key
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
key
);
err
!=
nil
{
t
.
Fatalf
(
"generate random key: %v"
,
err
)
}
return
key
}
func
TestEncryptDecryptRoundTrip
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
plaintexts
:=
[]
string
{
"hello world"
,
"short"
,
"a longer string with special chars: !@#$%^&*()"
,
`{"key":"value","num":42}`
,
"你好世界 unicode test 🎉"
,
strings
.
Repeat
(
"x"
,
10000
),
}
for
_
,
pt
:=
range
plaintexts
{
encrypted
,
err
:=
Encrypt
(
pt
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt(%q) error: %v"
,
pt
[
:
min
(
len
(
pt
),
30
)],
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt error for plaintext %q: %v"
,
pt
[
:
min
(
len
(
pt
),
30
)],
err
)
}
if
decrypted
!=
pt
{
t
.
Fatalf
(
"round-trip failed: got %q, want %q"
,
decrypted
[
:
min
(
len
(
decrypted
),
30
)],
pt
[
:
min
(
len
(
pt
),
30
)])
}
}
}
func
TestEncryptProducesDifferentCiphertexts
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
ct1
,
err
:=
Encrypt
(
"same plaintext"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"first Encrypt error: %v"
,
err
)
}
ct2
,
err
:=
Encrypt
(
"same plaintext"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"second Encrypt error: %v"
,
err
)
}
if
ct1
==
ct2
{
t
.
Fatal
(
"two encryptions of the same plaintext should produce different ciphertexts (random nonce)"
)
}
}
func
TestDecryptWithWrongKeyFails
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key1
:=
makeKey
(
t
)
key2
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
"secret data"
,
key1
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt error: %v"
,
err
)
}
_
,
err
=
Decrypt
(
encrypted
,
key2
)
if
err
==
nil
{
t
.
Fatal
(
"Decrypt with wrong key should fail, but got nil error"
)
}
}
func
TestEncryptRejectsInvalidKeyLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
badKeys
:=
[][]
byte
{
nil
,
make
([]
byte
,
0
),
make
([]
byte
,
16
),
make
([]
byte
,
31
),
make
([]
byte
,
33
),
make
([]
byte
,
64
),
}
for
_
,
key
:=
range
badKeys
{
_
,
err
:=
Encrypt
(
"test"
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Encrypt should reject key of length %d"
,
len
(
key
))
}
}
}
func
TestDecryptRejectsInvalidKeyLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
badKeys
:=
[][]
byte
{
nil
,
make
([]
byte
,
16
),
make
([]
byte
,
33
),
}
for
_
,
key
:=
range
badKeys
{
_
,
err
:=
Decrypt
(
"dummydata:dummydata:dummydata"
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Decrypt should reject key of length %d"
,
len
(
key
))
}
}
}
func
TestEncryptEmptyPlaintext
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
""
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt empty plaintext error: %v"
,
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt empty plaintext error: %v"
,
err
)
}
if
decrypted
!=
""
{
t
.
Fatalf
(
"expected empty string, got %q"
,
decrypted
)
}
}
func
TestEncryptDecryptUnicodeJSON
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
jsonContent
:=
`{"name":"测试用户","email":"test@example.com","balance":100.50}`
encrypted
,
err
:=
Encrypt
(
jsonContent
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt JSON error: %v"
,
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt JSON error: %v"
,
err
)
}
if
decrypted
!=
jsonContent
{
t
.
Fatalf
(
"JSON round-trip failed: got %q, want %q"
,
decrypted
,
jsonContent
)
}
}
func
TestDecryptInvalidFormat
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
invalidInputs
:=
[]
string
{
""
,
"nodelimiter"
,
"only:two"
,
"invalid:base64:!!!"
,
}
for
_
,
input
:=
range
invalidInputs
{
_
,
err
:=
Decrypt
(
input
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Decrypt(%q) should fail but got nil error"
,
input
)
}
}
}
func
TestCiphertextFormat
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
"test"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt error: %v"
,
err
)
}
parts
:=
strings
.
SplitN
(
encrypted
,
":"
,
3
)
if
len
(
parts
)
!=
3
{
t
.
Fatalf
(
"ciphertext should have format iv:authTag:ciphertext, got %d parts"
,
len
(
parts
))
}
for
i
,
part
:=
range
parts
{
if
part
==
""
{
t
.
Fatalf
(
"ciphertext part %d is empty"
,
i
)
}
}
}
backend/internal/payment/fee.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"github.com/shopspring/decimal"
)
// CalculatePayAmount computes the total pay amount given a recharge amount and
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
func
CalculatePayAmount
(
rechargeAmount
float64
,
feeRate
float64
)
string
{
amount
:=
decimal
.
NewFromFloat
(
rechargeAmount
)
if
feeRate
<=
0
{
return
amount
.
StringFixed
(
2
)
}
rate
:=
decimal
.
NewFromFloat
(
feeRate
)
fee
:=
amount
.
Mul
(
rate
)
.
Div
(
decimal
.
NewFromInt
(
100
))
.
RoundUp
(
2
)
return
amount
.
Add
(
fee
)
.
StringFixed
(
2
)
}
backend/internal/payment/fee_test.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"testing"
)
func
TestCalculatePayAmount
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
amount
float64
feeRate
float64
expected
string
}{
{
name
:
"zero fee rate returns same amount"
,
amount
:
100.00
,
feeRate
:
0
,
expected
:
"100.00"
,
},
{
name
:
"negative fee rate returns same amount"
,
amount
:
50.00
,
feeRate
:
-
5
,
expected
:
"50.00"
,
},
{
name
:
"1 percent fee rate"
,
amount
:
100.00
,
feeRate
:
1
,
expected
:
"101.00"
,
},
{
name
:
"5 percent fee on 200"
,
amount
:
200.00
,
feeRate
:
5
,
expected
:
"210.00"
,
},
{
name
:
"fee rounds UP to 2 decimal places"
,
amount
:
100.00
,
feeRate
:
3
,
expected
:
"103.00"
,
},
{
name
:
"fee rounds UP small remainder"
,
amount
:
10.00
,
feeRate
:
3.33
,
expected
:
"10.34"
,
// 10 * 3.33 / 100 = 0.333 -> round up -> 0.34
},
{
name
:
"very small amount"
,
amount
:
0.01
,
feeRate
:
1
,
expected
:
"0.02"
,
// 0.01 * 1/100 = 0.0001 -> round up -> 0.01 -> total 0.02
},
{
name
:
"large amount"
,
amount
:
99999.99
,
feeRate
:
10
,
expected
:
"109999.99"
,
// 99999.99 * 10/100 = 9999.999 -> round up -> 10000.00 -> total 109999.99
},
{
name
:
"100 percent fee rate doubles amount"
,
amount
:
50.00
,
feeRate
:
100
,
expected
:
"100.00"
,
},
{
name
:
"precision 0.01 fee difference"
,
amount
:
100.00
,
feeRate
:
1.01
,
expected
:
"101.01"
,
// 100 * 1.01/100 = 1.01
},
{
name
:
"precision 0.02 fee"
,
amount
:
100.00
,
feeRate
:
1.02
,
expected
:
"101.02"
,
},
{
name
:
"zero amount with positive fee"
,
amount
:
0
,
feeRate
:
5
,
expected
:
"0.00"
,
},
{
name
:
"fractional amount no fee"
,
amount
:
19.99
,
feeRate
:
0
,
expected
:
"19.99"
,
},
{
name
:
"fractional fee that causes rounding up"
,
amount
:
33.33
,
feeRate
:
7.77
,
expected
:
"35.92"
,
// 33.33 * 7.77 / 100 = 2.589741 -> round up -> 2.59 -> total 35.92
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
CalculatePayAmount
(
tt
.
amount
,
tt
.
feeRate
)
if
got
!=
tt
.
expected
{
t
.
Fatalf
(
"CalculatePayAmount(%v, %v) = %q, want %q"
,
tt
.
amount
,
tt
.
feeRate
,
got
,
tt
.
expected
)
}
})
}
}
backend/internal/payment/load_balancer.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
)
// Strategy represents a load balancing strategy for provider instance selection.
type
Strategy
string
const
(
StrategyRoundRobin
Strategy
=
"round-robin"
StrategyLeastAmount
Strategy
=
"least-amount"
)
// ChannelLimits holds limits for a single payment channel within a provider instance.
type
ChannelLimits
struct
{
DailyLimit
float64
`json:"dailyLimit,omitempty"`
SingleMin
float64
`json:"singleMin,omitempty"`
SingleMax
float64
`json:"singleMax,omitempty"`
}
// InstanceLimits holds per-channel limits for a provider instance (JSON).
type
InstanceLimits
map
[
string
]
ChannelLimits
// LoadBalancer selects a provider instance for a given payment type.
type
LoadBalancer
interface
{
GetInstanceConfig
(
ctx
context
.
Context
,
instanceID
int64
)
(
map
[
string
]
string
,
error
)
SelectInstance
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
strategy
Strategy
,
orderAmount
float64
)
(
*
InstanceSelection
,
error
)
}
// DefaultLoadBalancer implements LoadBalancer using database queries.
type
DefaultLoadBalancer
struct
{
db
*
dbent
.
Client
encryptionKey
[]
byte
counter
atomic
.
Uint64
}
// NewDefaultLoadBalancer creates a new load balancer.
func
NewDefaultLoadBalancer
(
db
*
dbent
.
Client
,
encryptionKey
[]
byte
)
*
DefaultLoadBalancer
{
return
&
DefaultLoadBalancer
{
db
:
db
,
encryptionKey
:
encryptionKey
}
}
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type
instanceCandidate
struct
{
inst
*
dbent
.
PaymentProviderInstance
dailyUsed
float64
// includes PENDING orders
}
// SelectInstance picks an enabled instance for the given provider key and payment type.
//
// Flow:
// 1. Query all enabled instances for providerKey, filter by supported paymentType
// 2. Batch-query daily usage (PENDING + PAID + COMPLETED + RECHARGING) for all candidates
// 3. Filter out instances where: single-min/max violated OR daily remaining < orderAmount
// 4. Pick from survivors using the configured strategy (round-robin / least-amount)
// 5. If all filtered out, fall back to full list (let the provider itself reject)
func
(
lb
*
DefaultLoadBalancer
)
SelectInstance
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
strategy
Strategy
,
orderAmount
float64
,
)
(
*
InstanceSelection
,
error
)
{
// Step 1: query enabled instances matching payment type.
instances
,
err
:=
lb
.
queryEnabledInstances
(
ctx
,
providerKey
,
paymentType
)
if
err
!=
nil
{
return
nil
,
err
}
// Step 2: batch-fetch daily usage for all candidates.
candidates
:=
lb
.
attachDailyUsage
(
ctx
,
instances
)
// Step 3: filter by limits.
available
:=
filterByLimits
(
candidates
,
paymentType
,
orderAmount
)
if
len
(
available
)
==
0
{
slog
.
Warn
(
"all instances exceeded limits, using full candidate list"
,
"provider"
,
providerKey
,
"payment_type"
,
paymentType
,
"order_amount"
,
orderAmount
,
"count"
,
len
(
candidates
))
available
=
candidates
}
// Step 4: pick by strategy.
selected
:=
lb
.
pickByStrategy
(
available
,
strategy
)
return
lb
.
buildSelection
(
selected
.
inst
)
}
// queryEnabledInstances returns enabled instances for providerKey that support paymentType.
func
(
lb
*
DefaultLoadBalancer
)
queryEnabledInstances
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
)
([]
*
dbent
.
PaymentProviderInstance
,
error
)
{
instances
,
err
:=
lb
.
db
.
PaymentProviderInstance
.
Query
()
.
Where
(
paymentproviderinstance
.
ProviderKey
(
providerKey
),
paymentproviderinstance
.
Enabled
(
true
),
)
.
Order
(
dbent
.
Asc
(
paymentproviderinstance
.
FieldSortOrder
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query provider instances: %w"
,
err
)
}
var
matched
[]
*
dbent
.
PaymentProviderInstance
for
_
,
inst
:=
range
instances
{
if
paymentType
==
providerKey
||
InstanceSupportsType
(
inst
.
SupportedTypes
,
paymentType
)
{
matched
=
append
(
matched
,
inst
)
}
}
if
len
(
matched
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"no enabled instance for provider %s type %s"
,
providerKey
,
paymentType
)
}
return
matched
,
nil
}
// attachDailyUsage queries daily usage for each instance in a single pass.
// Usage includes PENDING orders to avoid over-committing capacity.
func
(
lb
*
DefaultLoadBalancer
)
attachDailyUsage
(
ctx
context
.
Context
,
instances
[]
*
dbent
.
PaymentProviderInstance
,
)
[]
instanceCandidate
{
todayStart
:=
startOfDay
(
time
.
Now
())
// Collect instance IDs.
ids
:=
make
([]
string
,
len
(
instances
))
for
i
,
inst
:=
range
instances
{
ids
[
i
]
=
fmt
.
Sprintf
(
"%d"
,
inst
.
ID
)
}
// Batch query: sum pay_amount grouped by provider_instance_id.
type
row
struct
{
InstanceID
string
`json:"provider_instance_id"`
Sum
float64
`json:"sum"`
}
var
rows
[]
row
err
:=
lb
.
db
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
ProviderInstanceIDIn
(
ids
...
),
paymentorder
.
StatusIn
(
OrderStatusPending
,
OrderStatusPaid
,
OrderStatusCompleted
,
OrderStatusRecharging
,
),
paymentorder
.
CreatedAtGTE
(
todayStart
),
)
.
GroupBy
(
paymentorder
.
FieldProviderInstanceID
)
.
Aggregate
(
dbent
.
Sum
(
paymentorder
.
FieldPayAmount
))
.
Scan
(
ctx
,
&
rows
)
if
err
!=
nil
{
slog
.
Warn
(
"batch daily usage query failed, treating all as zero"
,
"error"
,
err
)
}
usageMap
:=
make
(
map
[
string
]
float64
,
len
(
rows
))
for
_
,
r
:=
range
rows
{
usageMap
[
r
.
InstanceID
]
=
r
.
Sum
}
candidates
:=
make
([]
instanceCandidate
,
len
(
instances
))
for
i
,
inst
:=
range
instances
{
candidates
[
i
]
=
instanceCandidate
{
inst
:
inst
,
dailyUsed
:
usageMap
[
fmt
.
Sprintf
(
"%d"
,
inst
.
ID
)],
}
}
return
candidates
}
// filterByLimits removes instances that cannot accommodate the order:
// - orderAmount outside single-transaction [min, max]
// - daily remaining capacity (limit - used) < orderAmount
func
filterByLimits
(
candidates
[]
instanceCandidate
,
paymentType
PaymentType
,
orderAmount
float64
)
[]
instanceCandidate
{
var
result
[]
instanceCandidate
for
_
,
c
:=
range
candidates
{
cl
:=
getInstanceChannelLimits
(
c
.
inst
,
paymentType
)
if
cl
.
SingleMin
>
0
&&
orderAmount
<
cl
.
SingleMin
{
slog
.
Info
(
"order below instance single min, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"order"
,
orderAmount
,
"min"
,
cl
.
SingleMin
)
continue
}
if
cl
.
SingleMax
>
0
&&
orderAmount
>
cl
.
SingleMax
{
slog
.
Info
(
"order above instance single max, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"order"
,
orderAmount
,
"max"
,
cl
.
SingleMax
)
continue
}
if
cl
.
DailyLimit
>
0
&&
c
.
dailyUsed
+
orderAmount
>
cl
.
DailyLimit
{
slog
.
Info
(
"instance daily remaining insufficient, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"used"
,
c
.
dailyUsed
,
"order"
,
orderAmount
,
"limit"
,
cl
.
DailyLimit
)
continue
}
result
=
append
(
result
,
c
)
}
return
result
}
// getInstanceChannelLimits returns the channel limits for a specific payment type.
func
getInstanceChannelLimits
(
inst
*
dbent
.
PaymentProviderInstance
,
paymentType
PaymentType
)
ChannelLimits
{
if
inst
.
Limits
==
""
{
return
ChannelLimits
{}
}
var
limits
InstanceLimits
if
err
:=
json
.
Unmarshal
([]
byte
(
inst
.
Limits
),
&
limits
);
err
!=
nil
{
return
ChannelLimits
{}
}
// For Stripe, limits are stored under the provider key "stripe".
lookupKey
:=
paymentType
if
inst
.
ProviderKey
==
"stripe"
{
lookupKey
=
"stripe"
}
if
cl
,
ok
:=
limits
[
lookupKey
];
ok
{
return
cl
}
return
ChannelLimits
{}
}
// pickByStrategy selects one instance from the available candidates.
func
(
lb
*
DefaultLoadBalancer
)
pickByStrategy
(
candidates
[]
instanceCandidate
,
strategy
Strategy
)
instanceCandidate
{
if
strategy
==
StrategyLeastAmount
&&
len
(
candidates
)
>
1
{
return
pickLeastAmount
(
candidates
)
}
// Default: round-robin.
idx
:=
lb
.
counter
.
Add
(
1
)
%
uint64
(
len
(
candidates
))
return
candidates
[
idx
]
}
// pickLeastAmount selects the instance with the lowest daily usage.
// No extra DB queries — usage was pre-fetched in attachDailyUsage.
func
pickLeastAmount
(
candidates
[]
instanceCandidate
)
instanceCandidate
{
best
:=
candidates
[
0
]
for
_
,
c
:=
range
candidates
[
1
:
]
{
if
c
.
dailyUsed
<
best
.
dailyUsed
{
best
=
c
}
}
return
best
}
func
(
lb
*
DefaultLoadBalancer
)
buildSelection
(
selected
*
dbent
.
PaymentProviderInstance
)
(
*
InstanceSelection
,
error
)
{
config
,
err
:=
lb
.
decryptConfig
(
selected
.
Config
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decrypt instance %d config: %w"
,
selected
.
ID
,
err
)
}
if
selected
.
PaymentMode
!=
""
{
config
[
"paymentMode"
]
=
selected
.
PaymentMode
}
return
&
InstanceSelection
{
InstanceID
:
fmt
.
Sprintf
(
"%d"
,
selected
.
ID
),
Config
:
config
,
SupportedTypes
:
selected
.
SupportedTypes
,
PaymentMode
:
selected
.
PaymentMode
,
},
nil
}
func
(
lb
*
DefaultLoadBalancer
)
decryptConfig
(
encrypted
string
)
(
map
[
string
]
string
,
error
)
{
plaintext
,
err
:=
Decrypt
(
encrypted
,
lb
.
encryptionKey
)
if
err
!=
nil
{
return
nil
,
err
}
var
config
map
[
string
]
string
if
err
:=
json
.
Unmarshal
([]
byte
(
plaintext
),
&
config
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal config: %w"
,
err
)
}
return
config
,
nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
func
(
lb
*
DefaultLoadBalancer
)
GetInstanceDailyAmount
(
ctx
context
.
Context
,
instanceID
string
)
(
float64
,
error
)
{
todayStart
:=
startOfDay
(
time
.
Now
())
var
result
[]
struct
{
Sum
float64
`json:"sum"`
}
err
:=
lb
.
db
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
ProviderInstanceID
(
instanceID
),
paymentorder
.
StatusIn
(
OrderStatusCompleted
,
OrderStatusPaid
,
OrderStatusRecharging
),
paymentorder
.
PaidAtGTE
(
todayStart
),
)
.
Aggregate
(
dbent
.
Sum
(
paymentorder
.
FieldPayAmount
))
.
Scan
(
ctx
,
&
result
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"query daily amount: %w"
,
err
)
}
if
len
(
result
)
>
0
{
return
result
[
0
]
.
Sum
,
nil
}
return
0
,
nil
}
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
}
// InstanceSupportsType checks if the given supported types string includes the target type.
// An empty supportedTypes string means all types are supported.
func
InstanceSupportsType
(
supportedTypes
string
,
target
PaymentType
)
bool
{
if
supportedTypes
==
""
{
return
true
}
for
_
,
t
:=
range
strings
.
Split
(
supportedTypes
,
","
)
{
if
strings
.
TrimSpace
(
t
)
==
target
{
return
true
}
}
return
false
}
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func
(
lb
*
DefaultLoadBalancer
)
GetInstanceConfig
(
ctx
context
.
Context
,
instanceID
int64
)
(
map
[
string
]
string
,
error
)
{
inst
,
err
:=
lb
.
db
.
PaymentProviderInstance
.
Get
(
ctx
,
instanceID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get instance %d: %w"
,
instanceID
,
err
)
}
return
lb
.
decryptConfig
(
inst
.
Config
)
}
backend/internal/payment/load_balancer_test.go
0 → 100644
View file @
97f14b7a
//go:build unit
package
payment
import
(
"encoding/json"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
)
func
TestInstanceSupportsType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
supportedTypes
string
target
PaymentType
expected
bool
}{
{
name
:
"exact match single type"
,
supportedTypes
:
"alipay"
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"no match single type"
,
supportedTypes
:
"wxpay"
,
target
:
"alipay"
,
expected
:
false
,
},
{
name
:
"match in comma-separated list"
,
supportedTypes
:
"alipay,wxpay,stripe"
,
target
:
"wxpay"
,
expected
:
true
,
},
{
name
:
"first in comma-separated list"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"last in comma-separated list"
,
supportedTypes
:
"alipay,wxpay,stripe"
,
target
:
"stripe"
,
expected
:
true
,
},
{
name
:
"no match in comma-separated list"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
"stripe"
,
expected
:
false
,
},
{
name
:
"empty target"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
""
,
expected
:
false
,
},
{
name
:
"types with spaces are trimmed"
,
supportedTypes
:
" alipay , wxpay "
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"partial match should not succeed"
,
supportedTypes
:
"alipay_direct"
,
target
:
"alipay"
,
expected
:
false
,
},
{
name
:
"empty supported types means all supported"
,
supportedTypes
:
""
,
target
:
"alipay"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
InstanceSupportsType
(
tt
.
supportedTypes
,
tt
.
target
)
if
got
!=
tt
.
expected
{
t
.
Fatalf
(
"InstanceSupportsType(%q, %q) = %v, want %v"
,
tt
.
supportedTypes
,
tt
.
target
,
got
,
tt
.
expected
)
}
})
}
}
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
func
testInstance
(
id
int64
,
providerKey
,
limits
string
)
*
dbent
.
PaymentProviderInstance
{
return
&
dbent
.
PaymentProviderInstance
{
ID
:
id
,
ProviderKey
:
providerKey
,
Limits
:
limits
,
Enabled
:
true
,
}
}
// makeLimitsJSON builds a limits JSON string for a single payment type.
func
makeLimitsJSON
(
paymentType
string
,
cl
ChannelLimits
)
string
{
m
:=
map
[
string
]
ChannelLimits
{
paymentType
:
cl
}
b
,
_
:=
json
.
Marshal
(
m
)
return
string
(
b
)
}
// ---------------------------------------------------------------------------
// filterByLimits
// ---------------------------------------------------------------------------
func
TestFilterByLimits
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
candidates
[]
instanceCandidate
paymentType
PaymentType
orderAmount
float64
wantIDs
[]
int64
// expected surviving instance IDs
}{
{
name
:
"order below SingleMin is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
5
,
wantIDs
:
nil
,
},
{
name
:
"order at exact SingleMin boundary passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
10
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"order above SingleMax is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
100
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
150
,
wantIDs
:
nil
,
},
{
name
:
"order at exact SingleMax boundary passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
100
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
100
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"daily used + orderAmount exceeding dailyLimit is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
480
},
},
paymentType
:
"alipay"
,
orderAmount
:
30
,
wantIDs
:
nil
,
// 480+30=510 > 500
},
{
name
:
"daily used + orderAmount equal to dailyLimit passes (strict greater-than)"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
480
},
},
paymentType
:
"alipay"
,
orderAmount
:
20
,
wantIDs
:
[]
int64
{
1
},
// 480+20=500, 500 > 500 is false → passes
},
{
name
:
"daily used + orderAmount below dailyLimit passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
400
},
},
paymentType
:
"alipay"
,
orderAmount
:
50
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"no limits configured passes through"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
99999
},
},
paymentType
:
"alipay"
,
orderAmount
:
100
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"multiple candidates with partial filtering"
,
candidates
:
[]
instanceCandidate
{
// singleMax=50, order=80 → filtered out
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
50
})),
dailyUsed
:
0
},
// no limits → passes
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
0
},
// singleMin=100, order=80 → filtered out
{
inst
:
testInstance
(
3
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
100
})),
dailyUsed
:
0
},
// daily limit ok → passes (500+80=580 < 1000)
{
inst
:
testInstance
(
4
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
80
,
wantIDs
:
[]
int64
{
2
,
4
},
},
{
name
:
"zero SingleMin and SingleMax means no single-transaction limit"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
0
,
SingleMax
:
0
,
DailyLimit
:
0
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
99999
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"all limits combined - order passes all checks"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
200
,
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
50
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"all limits combined - order fails SingleMin"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
200
,
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
5
,
wantIDs
:
nil
,
},
{
name
:
"empty candidates returns empty"
,
candidates
:
nil
,
paymentType
:
"alipay"
,
orderAmount
:
10
,
wantIDs
:
nil
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
filterByLimits
(
tt
.
candidates
,
tt
.
paymentType
,
tt
.
orderAmount
)
gotIDs
:=
make
([]
int64
,
len
(
got
))
for
i
,
c
:=
range
got
{
gotIDs
[
i
]
=
c
.
inst
.
ID
}
if
!
int64SliceEqual
(
gotIDs
,
tt
.
wantIDs
)
{
t
.
Fatalf
(
"filterByLimits() returned IDs %v, want %v"
,
gotIDs
,
tt
.
wantIDs
)
}
})
}
}
// ---------------------------------------------------------------------------
// pickLeastAmount
// ---------------------------------------------------------------------------
func
TestPickLeastAmount
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"picks candidate with lowest dailyUsed"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
300
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
200
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
2
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 2"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"with equal dailyUsed picks the first one"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
200
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
1
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 1 (first with lowest)"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"single candidate returns that candidate"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
42
,
"easypay"
,
""
),
dailyUsed
:
999
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
42
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 42"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"zero usage among non-zero picks zero"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
500
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
0
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
300
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
2
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 2"
,
got
.
inst
.
ID
)
}
})
}
// ---------------------------------------------------------------------------
// getInstanceChannelLimits
// ---------------------------------------------------------------------------
func
TestGetInstanceChannelLimits
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
inst
*
dbent
.
PaymentProviderInstance
paymentType
PaymentType
want
ChannelLimits
}{
{
name
:
"empty limits string returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
""
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{},
},
{
name
:
"invalid JSON returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
"not-json{"
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{},
},
{
name
:
"valid JSON with matching payment type"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5,"singleMax":200,"dailyLimit":1000}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
5
,
SingleMax
:
200
,
DailyLimit
:
1000
},
},
{
name
:
"payment type not in limits returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5,"singleMax":200}}`
),
paymentType
:
"wxpay"
,
want
:
ChannelLimits
{},
},
{
name
:
"stripe provider uses stripe lookup key regardless of payment type"
,
inst
:
testInstance
(
1
,
"stripe"
,
`{"stripe":{"singleMin":10,"singleMax":500,"dailyLimit":5000}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
500
,
DailyLimit
:
5000
},
},
{
name
:
"stripe provider ignores payment type key even if present"
,
inst
:
testInstance
(
1
,
"stripe"
,
`{"stripe":{"singleMin":10,"singleMax":500},"alipay":{"singleMin":1,"singleMax":100}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
500
},
},
{
name
:
"non-stripe provider uses payment type as lookup key"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5},"wxpay":{"singleMin":10}}`
),
paymentType
:
"wxpay"
,
want
:
ChannelLimits
{
SingleMin
:
10
},
},
{
name
:
"valid JSON with partial limits (only dailyLimit)"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"dailyLimit":800}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
DailyLimit
:
800
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
getInstanceChannelLimits
(
tt
.
inst
,
tt
.
paymentType
)
if
got
!=
tt
.
want
{
t
.
Fatalf
(
"getInstanceChannelLimits() = %+v, want %+v"
,
got
,
tt
.
want
)
}
})
}
}
// ---------------------------------------------------------------------------
// startOfDay
// ---------------------------------------------------------------------------
func
TestStartOfDay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
in
time
.
Time
want
time
.
Time
}{
{
name
:
"midday returns midnight of same day"
,
in
:
time
.
Date
(
2025
,
6
,
15
,
14
,
30
,
45
,
123456789
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
6
,
15
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"midnight returns same time"
,
in
:
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"last second of day returns midnight of same day"
,
in
:
time
.
Date
(
2025
,
12
,
31
,
23
,
59
,
59
,
999999999
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
12
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"preserves timezone location"
,
in
:
time
.
Date
(
2025
,
3
,
10
,
15
,
0
,
0
,
0
,
time
.
FixedZone
(
"CST"
,
8
*
3600
)),
want
:
time
.
Date
(
2025
,
3
,
10
,
0
,
0
,
0
,
0
,
time
.
FixedZone
(
"CST"
,
8
*
3600
)),
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
startOfDay
(
tt
.
in
)
if
!
got
.
Equal
(
tt
.
want
)
{
t
.
Fatalf
(
"startOfDay(%v) = %v, want %v"
,
tt
.
in
,
got
,
tt
.
want
)
}
// Also verify location is preserved.
if
got
.
Location
()
.
String
()
!=
tt
.
want
.
Location
()
.
String
()
{
t
.
Fatalf
(
"startOfDay() location = %v, want %v"
,
got
.
Location
(),
tt
.
want
.
Location
())
}
})
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// int64SliceEqual compares two int64 slices for equality.
// Both nil and empty slices are treated as equal.
func
int64SliceEqual
(
a
,
b
[]
int64
)
bool
{
if
len
(
a
)
==
0
&&
len
(
b
)
==
0
{
return
true
}
if
len
(
a
)
!=
len
(
b
)
{
return
false
}
for
i
:=
range
a
{
if
a
[
i
]
!=
b
[
i
]
{
return
false
}
}
return
true
}
backend/internal/payment/provider/alipay.go
0 → 100644
View file @
97f14b7a
package
provider
import
(
"context"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/smartwalle/alipay/v3"
)
// Alipay product codes.
const
(
alipayProductCodePagePay
=
"FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay
=
"QUICK_WAP_WAY"
)
// Alipay response constants.
const
(
alipayFundChangeYes
=
"Y"
alipayErrTradeNotExist
=
"ACQ.TRADE_NOT_EXIST"
alipayRefundSuffix
=
"-refund"
)
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type
Alipay
struct
{
instanceID
string
config
map
[
string
]
string
// appId, privateKey, publicKey (or alipayPublicKey), notifyUrl, returnUrl
mu
sync
.
Mutex
client
*
alipay
.
Client
}
// NewAlipay creates a new Alipay provider instance.
func
NewAlipay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Alipay
,
error
)
{
required
:=
[]
string
{
"appId"
,
"privateKey"
}
for
_
,
k
:=
range
required
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"alipay config missing required key: %s"
,
k
)
}
}
return
&
Alipay
{
instanceID
:
instanceID
,
config
:
config
,
},
nil
}
func
(
a
*
Alipay
)
getClient
()
(
*
alipay
.
Client
,
error
)
{
a
.
mu
.
Lock
()
defer
a
.
mu
.
Unlock
()
if
a
.
client
!=
nil
{
return
a
.
client
,
nil
}
client
,
err
:=
alipay
.
New
(
a
.
config
[
"appId"
],
a
.
config
[
"privateKey"
],
true
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay init client: %w"
,
err
)
}
pubKey
:=
a
.
config
[
"publicKey"
]
if
pubKey
==
""
{
pubKey
=
a
.
config
[
"alipayPublicKey"
]
}
if
pubKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"alipay config missing required key: publicKey (or alipayPublicKey)"
)
}
if
err
:=
client
.
LoadAliPayPublicKey
(
pubKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay load public key: %w"
,
err
)
}
a
.
client
=
client
return
a
.
client
,
nil
}
func
(
a
*
Alipay
)
Name
()
string
{
return
"Alipay"
}
func
(
a
*
Alipay
)
ProviderKey
()
string
{
return
payment
.
TypeAlipay
}
func
(
a
*
Alipay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeAlipayDirect
}
}
// CreatePayment creates an Alipay payment page URL.
func
(
a
*
Alipay
)
CreatePayment
(
_
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
notifyURL
:=
a
.
config
[
"notifyUrl"
]
if
req
.
NotifyURL
!=
""
{
notifyURL
=
req
.
NotifyURL
}
returnURL
:=
a
.
config
[
"returnUrl"
]
if
req
.
ReturnURL
!=
""
{
returnURL
=
req
.
ReturnURL
}
if
req
.
IsMobile
{
return
a
.
createTrade
(
client
,
req
,
notifyURL
,
returnURL
,
true
)
}
return
a
.
createTrade
(
client
,
req
,
notifyURL
,
returnURL
,
false
)
}
func
(
a
*
Alipay
)
createTrade
(
client
*
alipay
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
,
returnURL
string
,
isMobile
bool
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
if
isMobile
{
param
:=
alipay
.
TradeWapPay
{}
param
.
OutTradeNo
=
req
.
OrderID
param
.
TotalAmount
=
req
.
Amount
param
.
Subject
=
req
.
Subject
param
.
ProductCode
=
alipayProductCodeWapPay
param
.
NotifyURL
=
notifyURL
param
.
ReturnURL
=
returnURL
payURL
,
err
:=
client
.
TradeWapPay
(
param
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradeWapPay: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
payURL
.
String
(),
},
nil
}
param
:=
alipay
.
TradePagePay
{}
param
.
OutTradeNo
=
req
.
OrderID
param
.
TotalAmount
=
req
.
Amount
param
.
Subject
=
req
.
Subject
param
.
ProductCode
=
alipayProductCodePagePay
param
.
NotifyURL
=
notifyURL
param
.
ReturnURL
=
returnURL
payURL
,
err
:=
client
.
TradePagePay
(
param
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradePagePay: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
payURL
.
String
(),
QRCode
:
payURL
.
String
(),
},
nil
}
// QueryOrder queries the trade status via Alipay.
func
(
a
*
Alipay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
client
.
TradeQuery
(
ctx
,
alipay
.
TradeQuery
{
OutTradeNo
:
tradeNo
})
if
err
!=
nil
{
if
isTradeNotExist
(
err
)
{
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
tradeNo
,
Status
:
payment
.
ProviderStatusPending
,
},
nil
}
return
nil
,
fmt
.
Errorf
(
"alipay TradeQuery: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
switch
result
.
TradeStatus
{
case
alipay
.
TradeStatusSuccess
,
alipay
.
TradeStatusFinished
:
status
=
payment
.
ProviderStatusPaid
case
alipay
.
TradeStatusClosed
:
status
=
payment
.
ProviderStatusFailed
}
amount
,
err
:=
strconv
.
ParseFloat
(
result
.
TotalAmount
,
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse amount %q: %w"
,
result
.
TotalAmount
,
err
)
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
result
.
TradeNo
,
Status
:
status
,
Amount
:
amount
,
PaidAt
:
result
.
SendPayDate
,
},
nil
}
// VerifyNotification decodes and verifies an Alipay async notification.
func
(
a
*
Alipay
)
VerifyNotification
(
ctx
context
.
Context
,
rawBody
string
,
_
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
values
,
err
:=
url
.
ParseQuery
(
rawBody
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse notification: %w"
,
err
)
}
notification
,
err
:=
client
.
DecodeNotification
(
ctx
,
values
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay verify notification: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusFailed
if
notification
.
TradeStatus
==
alipay
.
TradeStatusSuccess
||
notification
.
TradeStatus
==
alipay
.
TradeStatusFinished
{
status
=
payment
.
ProviderStatusSuccess
}
amount
,
err
:=
strconv
.
ParseFloat
(
notification
.
TotalAmount
,
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse notification amount %q: %w"
,
notification
.
TotalAmount
,
err
)
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
notification
.
TradeNo
,
OrderID
:
notification
.
OutTradeNo
,
Amount
:
amount
,
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
// Refund requests a refund through Alipay.
func
(
a
*
Alipay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
client
.
TradeRefund
(
ctx
,
alipay
.
TradeRefund
{
OutTradeNo
:
req
.
OrderID
,
RefundAmount
:
req
.
Amount
,
RefundReason
:
req
.
Reason
,
OutRequestNo
:
fmt
.
Sprintf
(
"%s-refund-%d"
,
req
.
OrderID
,
time
.
Now
()
.
UnixNano
()),
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradeRefund: %w"
,
err
)
}
refundStatus
:=
payment
.
ProviderStatusPending
if
result
.
FundChange
==
alipayFundChangeYes
{
refundStatus
=
payment
.
ProviderStatusSuccess
}
refundID
:=
result
.
TradeNo
if
refundID
==
""
{
refundID
=
req
.
OrderID
+
alipayRefundSuffix
}
return
&
payment
.
RefundResponse
{
RefundID
:
refundID
,
Status
:
refundStatus
,
},
nil
}
// CancelPayment closes a pending trade on Alipay.
func
(
a
*
Alipay
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
err
}
_
,
err
=
client
.
TradeClose
(
ctx
,
alipay
.
TradeClose
{
OutTradeNo
:
tradeNo
})
if
err
!=
nil
{
if
isTradeNotExist
(
err
)
{
return
nil
}
return
fmt
.
Errorf
(
"alipay TradeClose: %w"
,
err
)
}
return
nil
}
func
isTradeNotExist
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
return
strings
.
Contains
(
err
.
Error
(),
alipayErrTradeNotExist
)
}
// Ensure interface compliance.
var
(
_
payment
.
Provider
=
(
*
Alipay
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Alipay
)(
nil
)
)
backend/internal/payment/provider/alipay_test.go
0 → 100644
View file @
97f14b7a
//go:build unit
package
provider
import
(
"errors"
"strings"
"testing"
)
func
TestIsTradeNotExist
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
err
error
want
bool
}{
{
name
:
"nil error returns false"
,
err
:
nil
,
want
:
false
,
},
{
name
:
"error containing ACQ.TRADE_NOT_EXIST returns true"
,
err
:
errors
.
New
(
"alipay: sub_code=ACQ.TRADE_NOT_EXIST, sub_msg=交易不存在"
),
want
:
true
,
},
{
name
:
"error not containing the code returns false"
,
err
:
errors
.
New
(
"alipay: sub_code=ACQ.SYSTEM_ERROR, sub_msg=系统错误"
),
want
:
false
,
},
{
name
:
"error with only partial match returns false"
,
err
:
errors
.
New
(
"ACQ.TRADE_NOT"
),
want
:
false
,
},
{
name
:
"error with exact constant value returns true"
,
err
:
errors
.
New
(
alipayErrTradeNotExist
),
want
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
isTradeNotExist
(
tt
.
err
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"isTradeNotExist(%v) = %v, want %v"
,
tt
.
err
,
got
,
tt
.
want
)
}
})
}
}
func
TestNewAlipay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
validConfig
:=
map
[
string
]
string
{
"appId"
:
"2021001234567890"
,
"privateKey"
:
"MIIEvQIBADANBgkqhkiG9w0BAQEFAASC..."
,
}
// helper to clone and override config fields
withOverride
:=
func
(
overrides
map
[
string
]
string
)
map
[
string
]
string
{
cfg
:=
make
(
map
[
string
]
string
,
len
(
validConfig
))
for
k
,
v
:=
range
validConfig
{
cfg
[
k
]
=
v
}
for
k
,
v
:=
range
overrides
{
cfg
[
k
]
=
v
}
return
cfg
}
tests
:=
[]
struct
{
name
string
config
map
[
string
]
string
wantErr
bool
errSubstr
string
}{
{
name
:
"valid config succeeds"
,
config
:
validConfig
,
wantErr
:
false
,
},
{
name
:
"missing appId"
,
config
:
withOverride
(
map
[
string
]
string
{
"appId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"appId"
,
},
{
name
:
"missing privateKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"privateKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"privateKey"
,
},
{
name
:
"nil config map returns error for appId"
,
config
:
map
[
string
]
string
{},
wantErr
:
true
,
errSubstr
:
"appId"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
NewAlipay
(
"test-instance"
,
tt
.
config
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Fatal
(
"expected error, got nil"
)
}
if
tt
.
errSubstr
!=
""
&&
!
strings
.
Contains
(
err
.
Error
(),
tt
.
errSubstr
)
{
t
.
Errorf
(
"error %q should contain %q"
,
err
.
Error
(),
tt
.
errSubstr
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
got
==
nil
{
t
.
Fatal
(
"expected non-nil Alipay instance"
)
}
if
got
.
instanceID
!=
"test-instance"
{
t
.
Errorf
(
"instanceID = %q, want %q"
,
got
.
instanceID
,
"test-instance"
)
}
})
}
}
backend/internal/payment/provider/easypay.go
0 → 100644
View file @
97f14b7a
// Package provider contains concrete payment provider implementations.
package
provider
import
(
"context"
"crypto/hmac"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// EasyPay constants.
const
(
easypayCodeSuccess
=
1
easypayStatusPaid
=
1
easypayHTTPTimeout
=
10
*
time
.
Second
maxEasypayResponseSize
=
1
<<
20
// 1MB
tradeStatusSuccess
=
"TRADE_SUCCESS"
signTypeMD5
=
"MD5"
paymentModePopup
=
"popup"
deviceMobile
=
"mobile"
)
// EasyPay implements payment.Provider for the EasyPay aggregation platform.
type
EasyPay
struct
{
instanceID
string
config
map
[
string
]
string
httpClient
*
http
.
Client
}
// NewEasyPay creates a new EasyPay provider.
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func
NewEasyPay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
EasyPay
,
error
)
{
for
_
,
k
:=
range
[]
string
{
"pid"
,
"pkey"
,
"apiBase"
,
"notifyUrl"
,
"returnUrl"
}
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"easypay config missing required key: %s"
,
k
)
}
}
return
&
EasyPay
{
instanceID
:
instanceID
,
config
:
config
,
httpClient
:
&
http
.
Client
{
Timeout
:
easypayHTTPTimeout
},
},
nil
}
func
(
e
*
EasyPay
)
Name
()
string
{
return
"EasyPay"
}
func
(
e
*
EasyPay
)
ProviderKey
()
string
{
return
payment
.
TypeEasyPay
}
func
(
e
*
EasyPay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeAlipay
,
payment
.
TypeWxpay
}
}
func
(
e
*
EasyPay
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
mode
:=
e
.
config
[
"paymentMode"
]
if
mode
==
paymentModePopup
{
return
e
.
createRedirectPayment
(
req
)
}
return
e
.
createAPIPayment
(
ctx
,
req
)
}
// createRedirectPayment builds a submit.php URL for browser redirect.
// No server-side API call — the user is redirected to EasyPay's hosted page.
// TradeNo is empty; it arrives via the notify callback after payment.
func
(
e
*
EasyPay
)
createRedirectPayment
(
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
notifyURL
,
returnURL
:=
e
.
resolveURLs
(
req
)
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"type"
:
req
.
PaymentType
,
"out_trade_no"
:
req
.
OrderID
,
"notify_url"
:
notifyURL
,
"return_url"
:
returnURL
,
"name"
:
req
.
Subject
,
"money"
:
req
.
Amount
,
}
if
cid
:=
e
.
resolveCID
(
req
.
PaymentType
);
cid
!=
""
{
params
[
"cid"
]
=
cid
}
if
req
.
IsMobile
{
params
[
"device"
]
=
deviceMobile
}
params
[
"sign"
]
=
easyPaySign
(
params
,
e
.
config
[
"pkey"
])
params
[
"sign_type"
]
=
signTypeMD5
q
:=
url
.
Values
{}
for
k
,
v
:=
range
params
{
q
.
Set
(
k
,
v
)
}
base
:=
strings
.
TrimRight
(
e
.
config
[
"apiBase"
],
"/"
)
payURL
:=
base
+
"/submit.php?"
+
q
.
Encode
()
return
&
payment
.
CreatePaymentResponse
{
PayURL
:
payURL
},
nil
}
// createAPIPayment calls mapi.php to get payurl/qrcode (existing behavior).
func
(
e
*
EasyPay
)
createAPIPayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
notifyURL
,
returnURL
:=
e
.
resolveURLs
(
req
)
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"type"
:
req
.
PaymentType
,
"out_trade_no"
:
req
.
OrderID
,
"notify_url"
:
notifyURL
,
"return_url"
:
returnURL
,
"name"
:
req
.
Subject
,
"money"
:
req
.
Amount
,
"clientip"
:
req
.
ClientIP
,
}
if
cid
:=
e
.
resolveCID
(
req
.
PaymentType
);
cid
!=
""
{
params
[
"cid"
]
=
cid
}
if
req
.
IsMobile
{
params
[
"device"
]
=
deviceMobile
}
params
[
"sign"
]
=
easyPaySign
(
params
,
e
.
config
[
"pkey"
])
params
[
"sign_type"
]
=
signTypeMD5
body
,
err
:=
e
.
post
(
ctx
,
strings
.
TrimRight
(
e
.
config
[
"apiBase"
],
"/"
)
+
"/mapi.php"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay create: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
TradeNo
string
`json:"trade_no"`
PayURL
string
`json:"payurl"`
PayURL2
string
`json:"payurl2"`
// H5 mobile payment URL
QRCode
string
`json:"qrcode"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse: %w"
,
err
)
}
if
resp
.
Code
!=
easypayCodeSuccess
{
return
nil
,
fmt
.
Errorf
(
"easypay error: %s"
,
resp
.
Msg
)
}
payURL
:=
resp
.
PayURL
if
req
.
IsMobile
&&
resp
.
PayURL2
!=
""
{
payURL
=
resp
.
PayURL2
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
resp
.
TradeNo
,
PayURL
:
payURL
,
QRCode
:
resp
.
QRCode
},
nil
}
// resolveURLs returns (notifyURL, returnURL) preferring request values,
// falling back to instance config.
func
(
e
*
EasyPay
)
resolveURLs
(
req
payment
.
CreatePaymentRequest
)
(
string
,
string
)
{
notifyURL
:=
req
.
NotifyURL
if
notifyURL
==
""
{
notifyURL
=
e
.
config
[
"notifyUrl"
]
}
returnURL
:=
req
.
ReturnURL
if
returnURL
==
""
{
returnURL
=
e
.
config
[
"returnUrl"
]
}
return
notifyURL
,
returnURL
}
func
(
e
*
EasyPay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
params
:=
map
[
string
]
string
{
"act"
:
"order"
,
"pid"
:
e
.
config
[
"pid"
],
"key"
:
e
.
config
[
"pkey"
],
"out_trade_no"
:
tradeNo
,
}
body
,
err
:=
e
.
post
(
ctx
,
e
.
config
[
"apiBase"
]
+
"/api.php"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay query: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
Status
int
`json:"status"`
Money
string
`json:"money"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse query: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
if
resp
.
Status
==
easypayStatusPaid
{
status
=
payment
.
ProviderStatusPaid
}
amount
,
_
:=
strconv
.
ParseFloat
(
resp
.
Money
,
64
)
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
tradeNo
,
Status
:
status
,
Amount
:
amount
},
nil
}
func
(
e
*
EasyPay
)
VerifyNotification
(
_
context
.
Context
,
rawBody
string
,
_
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
values
,
err
:=
url
.
ParseQuery
(
rawBody
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse notify: %w"
,
err
)
}
// url.ParseQuery already decodes values — no additional decode needed.
params
:=
make
(
map
[
string
]
string
)
for
k
:=
range
values
{
params
[
k
]
=
values
.
Get
(
k
)
}
sign
:=
params
[
"sign"
]
if
sign
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing sign"
)
}
if
!
easyPayVerifySign
(
params
,
e
.
config
[
"pkey"
],
sign
)
{
return
nil
,
fmt
.
Errorf
(
"invalid signature"
)
}
status
:=
payment
.
ProviderStatusFailed
if
params
[
"trade_status"
]
==
tradeStatusSuccess
{
status
=
payment
.
ProviderStatusSuccess
}
amount
,
_
:=
strconv
.
ParseFloat
(
params
[
"money"
],
64
)
return
&
payment
.
PaymentNotification
{
TradeNo
:
params
[
"trade_no"
],
OrderID
:
params
[
"out_trade_no"
],
Amount
:
amount
,
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
func
(
e
*
EasyPay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"key"
:
e
.
config
[
"pkey"
],
"trade_no"
:
req
.
TradeNo
,
"out_trade_no"
:
req
.
OrderID
,
"money"
:
req
.
Amount
,
}
body
,
err
:=
e
.
post
(
ctx
,
e
.
config
[
"apiBase"
]
+
"/api.php?act=refund"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay refund: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse refund: %w"
,
err
)
}
if
resp
.
Code
!=
easypayCodeSuccess
{
return
nil
,
fmt
.
Errorf
(
"easypay refund failed: %s"
,
resp
.
Msg
)
}
return
&
payment
.
RefundResponse
{
RefundID
:
req
.
TradeNo
,
Status
:
payment
.
ProviderStatusSuccess
},
nil
}
func
(
e
*
EasyPay
)
resolveCID
(
paymentType
string
)
string
{
if
strings
.
HasPrefix
(
paymentType
,
"alipay"
)
{
if
v
:=
e
.
config
[
"cidAlipay"
];
v
!=
""
{
return
v
}
return
e
.
config
[
"cid"
]
}
if
v
:=
e
.
config
[
"cidWxpay"
];
v
!=
""
{
return
v
}
return
e
.
config
[
"cid"
]
}
func
(
e
*
EasyPay
)
post
(
ctx
context
.
Context
,
endpoint
string
,
params
map
[
string
]
string
)
([]
byte
,
error
)
{
form
:=
url
.
Values
{}
for
k
,
v
:=
range
params
{
form
.
Set
(
k
,
v
)
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
endpoint
,
strings
.
NewReader
(
form
.
Encode
()))
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/x-www-form-urlencoded"
)
resp
,
err
:=
e
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
return
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxEasypayResponseSize
))
}
func
easyPaySign
(
params
map
[
string
]
string
,
pkey
string
)
string
{
keys
:=
make
([]
string
,
0
,
len
(
params
))
for
k
,
v
:=
range
params
{
if
k
==
"sign"
||
k
==
"sign_type"
||
v
==
""
{
continue
}
keys
=
append
(
keys
,
k
)
}
sort
.
Strings
(
keys
)
var
buf
strings
.
Builder
for
i
,
k
:=
range
keys
{
if
i
>
0
{
_
=
buf
.
WriteByte
(
'&'
)
}
_
,
_
=
buf
.
WriteString
(
k
+
"="
+
params
[
k
])
}
_
,
_
=
buf
.
WriteString
(
pkey
)
hash
:=
md5
.
Sum
([]
byte
(
buf
.
String
()))
return
hex
.
EncodeToString
(
hash
[
:
])
}
func
easyPayVerifySign
(
params
map
[
string
]
string
,
pkey
string
,
sign
string
)
bool
{
return
hmac
.
Equal
([]
byte
(
easyPaySign
(
params
,
pkey
)),
[]
byte
(
sign
))
}
backend/internal/payment/provider/easypay_sign_test.go
0 → 100644
View file @
97f14b7a
package
provider
import
(
"testing"
)
func
TestEasyPaySignConsistentOutput
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER123"
,
"name"
:
"Test Product"
,
"money"
:
"10.00"
,
}
pkey
:=
"test_secret_key"
sign1
:=
easyPaySign
(
params
,
pkey
)
sign2
:=
easyPaySign
(
params
,
pkey
)
if
sign1
!=
sign2
{
t
.
Fatalf
(
"easyPaySign should be deterministic: %q != %q"
,
sign1
,
sign2
)
}
if
len
(
sign1
)
!=
32
{
t
.
Fatalf
(
"MD5 hex should be 32 chars, got %d"
,
len
(
sign1
))
}
}
func
TestEasyPaySignExcludesSignAndSignType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"my_key"
base
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
withSign
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"sign"
:
"should_be_ignored"
,
"sign_type"
:
"MD5"
,
}
signBase
:=
easyPaySign
(
base
,
pkey
)
signWithExtra
:=
easyPaySign
(
withSign
,
pkey
)
if
signBase
!=
signWithExtra
{
t
.
Fatalf
(
"sign and sign_type should be excluded: base=%q, withExtra=%q"
,
signBase
,
signWithExtra
)
}
}
func
TestEasyPaySignExcludesEmptyValues
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"key123"
base
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
withEmpty
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"device"
:
""
,
"clientip"
:
""
,
}
signBase
:=
easyPaySign
(
base
,
pkey
)
signWithEmpty
:=
easyPaySign
(
withEmpty
,
pkey
)
if
signBase
!=
signWithEmpty
{
t
.
Fatalf
(
"empty values should be excluded: base=%q, withEmpty=%q"
,
signBase
,
signWithEmpty
)
}
}
func
TestEasyPayVerifySignValid
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER456"
,
"money"
:
"25.00"
,
}
pkey
:=
"secret"
sign
:=
easyPaySign
(
params
,
pkey
)
// Add sign to params (as would come in a real callback)
params
[
"sign"
]
=
sign
params
[
"sign_type"
]
=
"MD5"
if
!
easyPayVerifySign
(
params
,
pkey
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return true for a valid signature"
)
}
}
func
TestEasyPayVerifySignTampered
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER789"
,
"money"
:
"50.00"
,
}
pkey
:=
"secret"
sign
:=
easyPaySign
(
params
,
pkey
)
// Tamper with the amount
params
[
"money"
]
=
"99.99"
if
easyPayVerifySign
(
params
,
pkey
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return false for tampered params"
)
}
}
func
TestEasyPayVerifySignWrongKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"wxpay"
,
}
sign
:=
easyPaySign
(
params
,
"correct_key"
)
if
easyPayVerifySign
(
params
,
"wrong_key"
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return false with wrong key"
)
}
}
func
TestEasyPaySignEmptyParams
(
t
*
testing
.
T
)
{
t
.
Parallel
()
sign
:=
easyPaySign
(
map
[
string
]
string
{},
"key123"
)
if
sign
==
""
{
t
.
Fatal
(
"easyPaySign with empty params should still produce a hash"
)
}
if
len
(
sign
)
!=
32
{
t
.
Fatalf
(
"MD5 hex should be 32 chars, got %d"
,
len
(
sign
))
}
}
func
TestEasyPaySignSortOrder
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"test_key"
params1
:=
map
[
string
]
string
{
"a"
:
"1"
,
"b"
:
"2"
,
"c"
:
"3"
,
}
params2
:=
map
[
string
]
string
{
"c"
:
"3"
,
"a"
:
"1"
,
"b"
:
"2"
,
}
sign1
:=
easyPaySign
(
params1
,
pkey
)
sign2
:=
easyPaySign
(
params2
,
pkey
)
if
sign1
!=
sign2
{
t
.
Fatalf
(
"easyPaySign should be order-independent: %q != %q"
,
sign1
,
sign2
)
}
}
func
TestEasyPayVerifySignWrongSignValue
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
pkey
:=
"key"
if
easyPayVerifySign
(
params
,
pkey
,
"00000000000000000000000000000000"
)
{
t
.
Fatal
(
"easyPayVerifySign should return false for an incorrect sign value"
)
}
}
backend/internal/payment/provider/factory.go
0 → 100644
View file @
97f14b7a
package
provider
import
(
"fmt"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// CreateProvider creates a Provider from a provider key, instance ID and decrypted config.
func
CreateProvider
(
providerKey
string
,
instanceID
string
,
config
map
[
string
]
string
)
(
payment
.
Provider
,
error
)
{
switch
providerKey
{
case
payment
.
TypeEasyPay
:
return
NewEasyPay
(
instanceID
,
config
)
case
payment
.
TypeAlipay
:
return
NewAlipay
(
instanceID
,
config
)
case
payment
.
TypeWxpay
:
return
NewWxpay
(
instanceID
,
config
)
case
payment
.
TypeStripe
:
return
NewStripe
(
instanceID
,
config
)
default
:
return
nil
,
fmt
.
Errorf
(
"unknown provider key: %s"
,
providerKey
)
}
}
backend/internal/payment/provider/stripe.go
0 → 100644
View file @
97f14b7a
package
provider
import
(
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/payment"
stripe
"github.com/stripe/stripe-go/v85"
"github.com/stripe/stripe-go/v85/webhook"
)
// Stripe constants.
const
(
stripeCurrency
=
"cny"
stripeEventPaymentSuccess
=
"payment_intent.succeeded"
stripeEventPaymentFailed
=
"payment_intent.payment_failed"
)
// Stripe implements the payment.CancelableProvider interface for Stripe payments.
type
Stripe
struct
{
instanceID
string
config
map
[
string
]
string
mu
sync
.
Mutex
initialized
bool
sc
*
stripe
.
Client
}
// NewStripe creates a new Stripe provider instance.
func
NewStripe
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Stripe
,
error
)
{
if
config
[
"secretKey"
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe config missing required key: secretKey"
)
}
return
&
Stripe
{
instanceID
:
instanceID
,
config
:
config
,
},
nil
}
func
(
s
*
Stripe
)
ensureInit
()
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
!
s
.
initialized
{
s
.
sc
=
stripe
.
NewClient
(
s
.
config
[
"secretKey"
])
s
.
initialized
=
true
}
}
// GetPublishableKey returns the publishable key for frontend use.
func
(
s
*
Stripe
)
GetPublishableKey
()
string
{
return
s
.
config
[
"publishableKey"
]
}
func
(
s
*
Stripe
)
Name
()
string
{
return
"Stripe"
}
func
(
s
*
Stripe
)
ProviderKey
()
string
{
return
payment
.
TypeStripe
}
func
(
s
*
Stripe
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeStripe
}
}
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
var
stripePaymentMethodTypes
=
map
[
string
][]
string
{
payment
.
TypeCard
:
{
"card"
},
payment
.
TypeAlipay
:
{
"alipay"
},
payment
.
TypeWxpay
:
{
"wechat_pay"
},
payment
.
TypeLink
:
{
"link"
},
}
// CreatePayment creates a Stripe PaymentIntent.
func
(
s
*
Stripe
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
s
.
ensureInit
()
amountInCents
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe create payment: %w"
,
err
)
}
// Collect all Stripe payment_method_types from the instance's configured sub-methods
methods
:=
resolveStripeMethodTypes
(
req
.
InstanceSubMethods
)
pmTypes
:=
make
([]
*
string
,
len
(
methods
))
for
i
,
m
:=
range
methods
{
pmTypes
[
i
]
=
stripe
.
String
(
m
)
}
params
:=
&
stripe
.
PaymentIntentCreateParams
{
Amount
:
stripe
.
Int64
(
amountInCents
),
Currency
:
stripe
.
String
(
stripeCurrency
),
PaymentMethodTypes
:
pmTypes
,
Description
:
stripe
.
String
(
req
.
Subject
),
Metadata
:
map
[
string
]
string
{
"orderId"
:
req
.
OrderID
},
}
// WeChat Pay requires payment_method_options with client type
if
hasStripeMethod
(
methods
,
"wechat_pay"
)
{
params
.
PaymentMethodOptions
=
&
stripe
.
PaymentIntentCreatePaymentMethodOptionsParams
{
WeChatPay
:
&
stripe
.
PaymentIntentCreatePaymentMethodOptionsWeChatPayParams
{
Client
:
stripe
.
String
(
"web"
),
},
}
}
params
.
SetIdempotencyKey
(
fmt
.
Sprintf
(
"pi-%s"
,
req
.
OrderID
))
params
.
Context
=
ctx
pi
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Create
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe create payment: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
pi
.
ID
,
ClientSecret
:
pi
.
ClientSecret
,
},
nil
}
// QueryOrder retrieves a PaymentIntent by ID.
func
(
s
*
Stripe
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
s
.
ensureInit
()
pi
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Retrieve
(
ctx
,
tradeNo
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe query order: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
switch
pi
.
Status
{
case
stripe
.
PaymentIntentStatusSucceeded
:
status
=
payment
.
ProviderStatusPaid
case
stripe
.
PaymentIntentStatusCanceled
:
status
=
payment
.
ProviderStatusFailed
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
pi
.
ID
,
Status
:
status
,
Amount
:
payment
.
FenToYuan
(
pi
.
Amount
),
},
nil
}
// VerifyNotification verifies a Stripe webhook event.
func
(
s
*
Stripe
)
VerifyNotification
(
_
context
.
Context
,
rawBody
string
,
headers
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
s
.
ensureInit
()
webhookSecret
:=
s
.
config
[
"webhookSecret"
]
if
webhookSecret
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe webhookSecret not configured"
)
}
sig
:=
headers
[
"stripe-signature"
]
if
sig
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe notification missing stripe-signature header"
)
}
event
,
err
:=
webhook
.
ConstructEvent
([]
byte
(
rawBody
),
sig
,
webhookSecret
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe verify notification: %w"
,
err
)
}
switch
event
.
Type
{
case
stripeEventPaymentSuccess
:
return
parseStripePaymentIntent
(
&
event
,
payment
.
ProviderStatusSuccess
,
rawBody
)
case
stripeEventPaymentFailed
:
return
parseStripePaymentIntent
(
&
event
,
payment
.
ProviderStatusFailed
,
rawBody
)
}
return
nil
,
nil
}
func
parseStripePaymentIntent
(
event
*
stripe
.
Event
,
status
string
,
rawBody
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
var
pi
stripe
.
PaymentIntent
if
err
:=
json
.
Unmarshal
(
event
.
Data
.
Raw
,
&
pi
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe parse payment_intent: %w"
,
err
)
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
pi
.
ID
,
OrderID
:
pi
.
Metadata
[
"orderId"
],
Amount
:
payment
.
FenToYuan
(
pi
.
Amount
),
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
// Refund creates a Stripe refund.
func
(
s
*
Stripe
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
s
.
ensureInit
()
amountInCents
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe refund: %w"
,
err
)
}
params
:=
&
stripe
.
RefundCreateParams
{
PaymentIntent
:
stripe
.
String
(
req
.
TradeNo
),
Amount
:
stripe
.
Int64
(
amountInCents
),
Reason
:
stripe
.
String
(
string
(
stripe
.
RefundReasonRequestedByCustomer
)),
}
params
.
Context
=
ctx
r
,
err
:=
s
.
sc
.
V1Refunds
.
Create
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe refund: %w"
,
err
)
}
refundStatus
:=
payment
.
ProviderStatusPending
if
r
.
Status
==
stripe
.
RefundStatusSucceeded
{
refundStatus
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
RefundResponse
{
RefundID
:
r
.
ID
,
Status
:
refundStatus
,
},
nil
}
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
func
resolveStripeMethodTypes
(
instanceSubMethods
string
)
[]
string
{
if
instanceSubMethods
==
""
{
return
[]
string
{
"card"
}
}
var
methods
[]
string
for
_
,
t
:=
range
strings
.
Split
(
instanceSubMethods
,
","
)
{
t
=
strings
.
TrimSpace
(
t
)
if
mapped
,
ok
:=
stripePaymentMethodTypes
[
t
];
ok
{
methods
=
append
(
methods
,
mapped
...
)
}
}
if
len
(
methods
)
==
0
{
return
[]
string
{
"card"
}
}
return
methods
}
// hasStripeMethod checks if the given Stripe method list contains the target method.
func
hasStripeMethod
(
methods
[]
string
,
target
string
)
bool
{
for
_
,
m
:=
range
methods
{
if
m
==
target
{
return
true
}
}
return
false
}
// CancelPayment cancels a pending PaymentIntent.
func
(
s
*
Stripe
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
s
.
ensureInit
()
_
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Cancel
(
ctx
,
tradeNo
,
nil
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"stripe cancel payment: %w"
,
err
)
}
return
nil
}
// Ensure interface compliance.
var
(
_
payment
.
Provider
=
(
*
Stripe
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Stripe
)(
nil
)
)
backend/internal/payment/provider/wxpay.go
0 → 100644
View file @
97f14b7a
package
provider
import
(
"bytes"
"context"
"crypto/rsa"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
)
// WeChat Pay constants.
const
(
wxpayCurrency
=
"CNY"
wxpayH5Type
=
"Wap"
)
// WeChat Pay trade states.
const
(
wxpayTradeStateSuccess
=
"SUCCESS"
wxpayTradeStateRefund
=
"REFUND"
wxpayTradeStateClosed
=
"CLOSED"
wxpayTradeStatePayError
=
"PAYERROR"
)
// WeChat Pay notification event types.
const
(
wxpayEventTransactionSuccess
=
"TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const
(
wxpayErrNoAuth
=
"NO_AUTH"
)
type
Wxpay
struct
{
instanceID
string
config
map
[
string
]
string
mu
sync
.
Mutex
coreClient
*
core
.
Client
notifyHandler
*
notify
.
Handler
}
func
NewWxpay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Wxpay
,
error
)
{
required
:=
[]
string
{
"appId"
,
"mchId"
,
"privateKey"
,
"apiV3Key"
,
"publicKey"
,
"publicKeyId"
,
"certSerial"
}
for
_
,
k
:=
range
required
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"wxpay config missing required key: %s"
,
k
)
}
}
if
len
(
config
[
"apiV3Key"
])
!=
32
{
return
nil
,
fmt
.
Errorf
(
"wxpay apiV3Key must be exactly 32 bytes, got %d"
,
len
(
config
[
"apiV3Key"
]))
}
return
&
Wxpay
{
instanceID
:
instanceID
,
config
:
config
},
nil
}
func
(
w
*
Wxpay
)
Name
()
string
{
return
"Wxpay"
}
func
(
w
*
Wxpay
)
ProviderKey
()
string
{
return
payment
.
TypeWxpay
}
func
(
w
*
Wxpay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeWxpayDirect
}
}
func
formatPEM
(
key
,
keyType
string
)
string
{
key
=
strings
.
TrimSpace
(
key
)
if
strings
.
HasPrefix
(
key
,
"-----BEGIN"
)
{
return
key
}
return
fmt
.
Sprintf
(
"-----BEGIN %s-----
\n
%s
\n
-----END %s-----"
,
keyType
,
key
,
keyType
)
}
func
(
w
*
Wxpay
)
ensureClient
()
(
*
core
.
Client
,
error
)
{
w
.
mu
.
Lock
()
defer
w
.
mu
.
Unlock
()
if
w
.
coreClient
!=
nil
{
return
w
.
coreClient
,
nil
}
privateKey
,
publicKey
,
err
:=
w
.
loadKeyPair
()
if
err
!=
nil
{
return
nil
,
err
}
certSerial
:=
w
.
config
[
"certSerial"
]
verifier
:=
verifiers
.
NewSHA256WithRSAPubkeyVerifier
(
w
.
config
[
"publicKeyId"
],
*
publicKey
)
client
,
err
:=
core
.
NewClient
(
context
.
Background
(),
option
.
WithMerchantCredential
(
w
.
config
[
"mchId"
],
certSerial
,
privateKey
),
option
.
WithVerifier
(
verifier
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay init client: %w"
,
err
)
}
handler
,
err
:=
notify
.
NewRSANotifyHandler
(
w
.
config
[
"apiV3Key"
],
verifier
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay init notify handler: %w"
,
err
)
}
w
.
notifyHandler
=
handler
w
.
coreClient
=
client
return
w
.
coreClient
,
nil
}
func
(
w
*
Wxpay
)
loadKeyPair
()
(
*
rsa
.
PrivateKey
,
*
rsa
.
PublicKey
,
error
)
{
privateKey
,
err
:=
utils
.
LoadPrivateKey
(
formatPEM
(
w
.
config
[
"privateKey"
],
"PRIVATE KEY"
))
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"wxpay load private key: %w"
,
err
)
}
publicKey
,
err
:=
utils
.
LoadPublicKey
(
formatPEM
(
w
.
config
[
"publicKey"
],
"PUBLIC KEY"
))
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"wxpay load public key: %w"
,
err
)
}
return
privateKey
,
publicKey
,
nil
}
func
(
w
*
Wxpay
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
client
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
// Request-first, config-fallback (consistent with EasyPay/Alipay)
notifyURL
:=
req
.
NotifyURL
if
notifyURL
==
""
{
notifyURL
=
w
.
config
[
"notifyUrl"
]
}
if
notifyURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"wxpay notifyUrl is required"
)
}
totalFen
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay create payment: %w"
,
err
)
}
if
req
.
IsMobile
&&
req
.
ClientIP
!=
""
{
resp
,
err
:=
w
.
createOrder
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
,
true
)
if
err
==
nil
{
return
resp
,
nil
}
if
!
strings
.
Contains
(
err
.
Error
(),
wxpayErrNoAuth
)
{
return
nil
,
err
}
slog
.
Warn
(
"wxpay H5 payment not authorized, falling back to native"
,
"order"
,
req
.
OrderID
)
}
return
w
.
createOrder
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
,
false
)
}
func
(
w
*
Wxpay
)
createOrder
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
,
useH5
bool
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
if
useH5
{
return
w
.
prepayH5
(
ctx
,
c
,
req
,
notifyURL
,
totalFen
)
}
return
w
.
prepayNative
(
ctx
,
c
,
req
,
notifyURL
,
totalFen
)
}
func
(
w
*
Wxpay
)
prepayNative
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
svc
:=
native
.
NativeApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
resp
,
_
,
err
:=
svc
.
Prepay
(
ctx
,
native
.
PrepayRequest
{
Appid
:
core
.
String
(
w
.
config
[
"appId"
]),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
Description
:
core
.
String
(
req
.
Subject
),
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
NotifyUrl
:
core
.
String
(
notifyURL
),
Amount
:
&
native
.
Amount
{
Total
:
core
.
Int64
(
totalFen
),
Currency
:
&
cur
},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay native prepay: %w"
,
err
)
}
codeURL
:=
""
if
resp
.
CodeUrl
!=
nil
{
codeURL
=
*
resp
.
CodeUrl
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
QRCode
:
codeURL
},
nil
}
func
(
w
*
Wxpay
)
prepayH5
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
svc
:=
h5
.
H5ApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
tp
:=
wxpayH5Type
resp
,
_
,
err
:=
svc
.
Prepay
(
ctx
,
h5
.
PrepayRequest
{
Appid
:
core
.
String
(
w
.
config
[
"appId"
]),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
Description
:
core
.
String
(
req
.
Subject
),
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
NotifyUrl
:
core
.
String
(
notifyURL
),
Amount
:
&
h5
.
Amount
{
Total
:
core
.
Int64
(
totalFen
),
Currency
:
&
cur
},
SceneInfo
:
&
h5
.
SceneInfo
{
PayerClientIp
:
core
.
String
(
req
.
ClientIP
),
H5Info
:
&
h5
.
H5Info
{
Type
:
&
tp
}},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay h5 prepay: %w"
,
err
)
}
h5URL
:=
""
if
resp
.
H5Url
!=
nil
{
h5URL
=
*
resp
.
H5Url
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
h5URL
},
nil
}
func
wxSV
(
s
*
string
)
string
{
if
s
==
nil
{
return
""
}
return
*
s
}
func
mapWxState
(
s
string
)
string
{
switch
s
{
case
wxpayTradeStateSuccess
:
return
payment
.
ProviderStatusPaid
case
wxpayTradeStateRefund
:
return
payment
.
ProviderStatusRefunded
case
wxpayTradeStateClosed
,
wxpayTradeStatePayError
:
return
payment
.
ProviderStatusFailed
default
:
return
payment
.
ProviderStatusPending
}
}
func
(
w
*
Wxpay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
svc
:=
native
.
NativeApiService
{
Client
:
c
}
tx
,
_
,
err
:=
svc
.
QueryOrderByOutTradeNo
(
ctx
,
native
.
QueryOrderByOutTradeNoRequest
{
OutTradeNo
:
core
.
String
(
tradeNo
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay query order: %w"
,
err
)
}
var
amt
float64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
amt
=
payment
.
FenToYuan
(
*
tx
.
Amount
.
Total
)
}
id
:=
tradeNo
if
tx
.
TransactionId
!=
nil
{
id
=
*
tx
.
TransactionId
}
pa
:=
""
if
tx
.
SuccessTime
!=
nil
{
pa
=
*
tx
.
SuccessTime
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
id
,
Status
:
mapWxState
(
wxSV
(
tx
.
TradeState
)),
Amount
:
amt
,
PaidAt
:
pa
},
nil
}
func
(
w
*
Wxpay
)
VerifyNotification
(
ctx
context
.
Context
,
rawBody
string
,
headers
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
if
_
,
err
:=
w
.
ensureClient
();
err
!=
nil
{
return
nil
,
err
}
r
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
"/"
,
io
.
NopCloser
(
bytes
.
NewBufferString
(
rawBody
)))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay construct request: %w"
,
err
)
}
for
k
,
v
:=
range
headers
{
r
.
Header
.
Set
(
k
,
v
)
}
var
tx
payments
.
Transaction
nr
,
err
:=
w
.
notifyHandler
.
ParseNotifyRequest
(
ctx
,
r
,
&
tx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay verify notification: %w"
,
err
)
}
if
nr
.
EventType
!=
wxpayEventTransactionSuccess
{
return
nil
,
nil
}
var
amt
float64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
amt
=
payment
.
FenToYuan
(
*
tx
.
Amount
.
Total
)
}
st
:=
payment
.
ProviderStatusFailed
if
wxSV
(
tx
.
TradeState
)
==
wxpayTradeStateSuccess
{
st
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
wxSV
(
tx
.
TransactionId
),
OrderID
:
wxSV
(
tx
.
OutTradeNo
),
Amount
:
amt
,
Status
:
st
,
RawData
:
rawBody
,
},
nil
}
func
(
w
*
Wxpay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
rf
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay refund amount: %w"
,
err
)
}
tf
,
err
:=
w
.
queryOrderTotalFen
(
ctx
,
c
,
req
.
OrderID
)
if
err
!=
nil
{
return
nil
,
err
}
rs
:=
refunddomestic
.
RefundsApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
res
,
_
,
err
:=
rs
.
Create
(
ctx
,
refunddomestic
.
CreateRequest
{
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
OutRefundNo
:
core
.
String
(
fmt
.
Sprintf
(
"%s-refund-%d"
,
req
.
OrderID
,
time
.
Now
()
.
UnixNano
())),
Reason
:
core
.
String
(
req
.
Reason
),
Amount
:
&
refunddomestic
.
AmountReq
{
Refund
:
core
.
Int64
(
rf
),
Total
:
core
.
Int64
(
tf
),
Currency
:
&
cur
},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay refund: %w"
,
err
)
}
rid
:=
wxSV
(
res
.
RefundId
)
if
rid
==
""
{
rid
=
fmt
.
Sprintf
(
"%s-refund"
,
req
.
OrderID
)
}
st
:=
payment
.
ProviderStatusPending
if
res
.
Status
!=
nil
&&
*
res
.
Status
==
refunddomestic
.
STATUS_SUCCESS
{
st
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
RefundResponse
{
RefundID
:
rid
,
Status
:
st
},
nil
}
func
(
w
*
Wxpay
)
queryOrderTotalFen
(
ctx
context
.
Context
,
c
*
core
.
Client
,
orderID
string
)
(
int64
,
error
)
{
svc
:=
native
.
NativeApiService
{
Client
:
c
}
tx
,
_
,
err
:=
svc
.
QueryOrderByOutTradeNo
(
ctx
,
native
.
QueryOrderByOutTradeNoRequest
{
OutTradeNo
:
core
.
String
(
orderID
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"wxpay refund query order: %w"
,
err
)
}
var
tf
int64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
tf
=
*
tx
.
Amount
.
Total
}
return
tf
,
nil
}
func
(
w
*
Wxpay
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
err
}
svc
:=
native
.
NativeApiService
{
Client
:
c
}
_
,
err
=
svc
.
CloseOrder
(
ctx
,
native
.
CloseOrderRequest
{
OutTradeNo
:
core
.
String
(
tradeNo
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
fmt
.
Errorf
(
"wxpay cancel payment: %w"
,
err
)
}
return
nil
}
var
(
_
payment
.
Provider
=
(
*
Wxpay
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Wxpay
)(
nil
)
)
backend/internal/payment/provider/wxpay_test.go
0 → 100644
View file @
97f14b7a
//go:build unit
package
provider
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func
TestMapWxState
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
string
want
string
}{
{
name
:
"SUCCESS maps to paid"
,
input
:
wxpayTradeStateSuccess
,
want
:
payment
.
ProviderStatusPaid
,
},
{
name
:
"REFUND maps to refunded"
,
input
:
wxpayTradeStateRefund
,
want
:
payment
.
ProviderStatusRefunded
,
},
{
name
:
"CLOSED maps to failed"
,
input
:
wxpayTradeStateClosed
,
want
:
payment
.
ProviderStatusFailed
,
},
{
name
:
"PAYERROR maps to failed"
,
input
:
wxpayTradeStatePayError
,
want
:
payment
.
ProviderStatusFailed
,
},
{
name
:
"unknown state maps to pending"
,
input
:
"NOTPAY"
,
want
:
payment
.
ProviderStatusPending
,
},
{
name
:
"empty string maps to pending"
,
input
:
""
,
want
:
payment
.
ProviderStatusPending
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
mapWxState
(
tt
.
input
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"mapWxState(%q) = %q, want %q"
,
tt
.
input
,
got
,
tt
.
want
)
}
})
}
}
func
TestWxSV
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
*
string
want
string
}{
{
name
:
"nil pointer returns empty string"
,
input
:
nil
,
want
:
""
,
},
{
name
:
"non-nil pointer returns value"
,
input
:
strPtr
(
"hello"
),
want
:
"hello"
,
},
{
name
:
"pointer to empty string returns empty string"
,
input
:
strPtr
(
""
),
want
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
wxSV
(
tt
.
input
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"wxSV() = %q, want %q"
,
got
,
tt
.
want
)
}
})
}
}
func
strPtr
(
s
string
)
*
string
{
return
&
s
}
func
TestFormatPEM
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
key
string
keyType
string
want
string
}{
{
name
:
"raw key gets wrapped with headers"
,
key
:
"MIIBIjANBgkqhki..."
,
keyType
:
"PUBLIC KEY"
,
want
:
"-----BEGIN PUBLIC KEY-----
\n
MIIBIjANBgkqhki...
\n
-----END PUBLIC KEY-----"
,
},
{
name
:
"already formatted key is returned as-is"
,
key
:
"-----BEGIN PRIVATE KEY-----
\n
MIIEvQIBADANBg...
\n
-----END PRIVATE KEY-----"
,
keyType
:
"PRIVATE KEY"
,
want
:
"-----BEGIN PRIVATE KEY-----
\n
MIIEvQIBADANBg...
\n
-----END PRIVATE KEY-----"
,
},
{
name
:
"key with leading/trailing whitespace is trimmed before check"
,
key
:
"
\n
MIIBIjANBgkqhki...
\n
"
,
keyType
:
"PUBLIC KEY"
,
want
:
"-----BEGIN PUBLIC KEY-----
\n
MIIBIjANBgkqhki...
\n
-----END PUBLIC KEY-----"
,
},
{
name
:
"already formatted key with whitespace is trimmed and returned"
,
key
:
" -----BEGIN RSA PRIVATE KEY-----
\n
data
\n
-----END RSA PRIVATE KEY----- "
,
keyType
:
"RSA PRIVATE KEY"
,
want
:
"-----BEGIN RSA PRIVATE KEY-----
\n
data
\n
-----END RSA PRIVATE KEY-----"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
formatPEM
(
tt
.
key
,
tt
.
keyType
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"formatPEM(%q, %q) =
\n
%s
\n
want:
\n
%s"
,
tt
.
key
,
tt
.
keyType
,
got
,
tt
.
want
)
}
})
}
}
func
TestNewWxpay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
validConfig
:=
map
[
string
]
string
{
"appId"
:
"wx1234567890"
,
"mchId"
:
"1234567890"
,
"privateKey"
:
"fake-private-key"
,
"apiV3Key"
:
"12345678901234567890123456789012"
,
// exactly 32 bytes
"publicKey"
:
"fake-public-key"
,
"publicKeyId"
:
"key-id-001"
,
"certSerial"
:
"SERIAL001"
,
}
// helper to clone and override config fields
withOverride
:=
func
(
overrides
map
[
string
]
string
)
map
[
string
]
string
{
cfg
:=
make
(
map
[
string
]
string
,
len
(
validConfig
))
for
k
,
v
:=
range
validConfig
{
cfg
[
k
]
=
v
}
for
k
,
v
:=
range
overrides
{
cfg
[
k
]
=
v
}
return
cfg
}
tests
:=
[]
struct
{
name
string
config
map
[
string
]
string
wantErr
bool
errSubstr
string
}{
{
name
:
"valid config succeeds"
,
config
:
validConfig
,
wantErr
:
false
,
},
{
name
:
"missing appId"
,
config
:
withOverride
(
map
[
string
]
string
{
"appId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"appId"
,
},
{
name
:
"missing mchId"
,
config
:
withOverride
(
map
[
string
]
string
{
"mchId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"mchId"
,
},
{
name
:
"missing privateKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"privateKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"privateKey"
,
},
{
name
:
"missing apiV3Key"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
""
}),
wantErr
:
true
,
errSubstr
:
"apiV3Key"
,
},
{
name
:
"missing publicKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"publicKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"publicKey"
,
},
{
name
:
"missing publicKeyId"
,
config
:
withOverride
(
map
[
string
]
string
{
"publicKeyId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"publicKeyId"
,
},
{
name
:
"apiV3Key too short"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
"short"
}),
wantErr
:
true
,
errSubstr
:
"exactly 32 bytes"
,
},
{
name
:
"apiV3Key too long"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
"123456789012345678901234567890123"
}),
// 33 bytes
wantErr
:
true
,
errSubstr
:
"exactly 32 bytes"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
NewWxpay
(
"test-instance"
,
tt
.
config
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Fatal
(
"expected error, got nil"
)
}
if
tt
.
errSubstr
!=
""
&&
!
strings
.
Contains
(
err
.
Error
(),
tt
.
errSubstr
)
{
t
.
Errorf
(
"error %q should contain %q"
,
err
.
Error
(),
tt
.
errSubstr
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
got
==
nil
{
t
.
Fatal
(
"expected non-nil Wxpay instance"
)
}
if
got
.
instanceID
!=
"test-instance"
{
t
.
Errorf
(
"instanceID = %q, want %q"
,
got
.
instanceID
,
"test-instance"
)
}
})
}
}
backend/internal/payment/registry.go
0 → 100644
View file @
97f14b7a
package
payment
import
(
"sync"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Registry is a thread-safe registry mapping PaymentType to Provider.
type
Registry
struct
{
mu
sync
.
RWMutex
providers
map
[
PaymentType
]
Provider
}
// ErrProviderNotFound is returned when a requested payment provider is not registered.
var
ErrProviderNotFound
=
infraerrors
.
NotFound
(
"PROVIDER_NOT_FOUND"
,
"payment provider not registered"
)
// NewRegistry creates a new empty provider registry.
func
NewRegistry
()
*
Registry
{
return
&
Registry
{
providers
:
make
(
map
[
PaymentType
]
Provider
),
}
}
// Register adds a provider for each of its supported payment types.
// If a type was previously registered, it is overwritten.
func
(
r
*
Registry
)
Register
(
p
Provider
)
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
for
_
,
t
:=
range
p
.
SupportedTypes
()
{
r
.
providers
[
t
]
=
p
}
}
// GetProvider returns the provider registered for the given payment type.
func
(
r
*
Registry
)
GetProvider
(
t
PaymentType
)
(
Provider
,
error
)
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
p
,
ok
:=
r
.
providers
[
t
]
if
!
ok
{
return
nil
,
ErrProviderNotFound
}
return
p
,
nil
}
// GetProviderByKey returns the first provider whose ProviderKey matches the given key.
func
(
r
*
Registry
)
GetProviderByKey
(
key
string
)
(
Provider
,
error
)
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
for
_
,
p
:=
range
r
.
providers
{
if
p
.
ProviderKey
()
==
key
{
return
p
,
nil
}
}
return
nil
,
ErrProviderNotFound
}
// GetProviderKey returns the provider key for the given payment type, or empty string if not found.
func
(
r
*
Registry
)
GetProviderKey
(
t
PaymentType
)
string
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
p
,
ok
:=
r
.
providers
[
t
]
if
!
ok
{
return
""
}
return
p
.
ProviderKey
()
}
// SupportedTypes returns all currently registered payment types.
func
(
r
*
Registry
)
SupportedTypes
()
[]
PaymentType
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
types
:=
make
([]
PaymentType
,
0
,
len
(
r
.
providers
))
for
t
:=
range
r
.
providers
{
types
=
append
(
types
,
t
)
}
return
types
}
// Clear removes all registered providers.
func
(
r
*
Registry
)
Clear
()
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
r
.
providers
=
make
(
map
[
PaymentType
]
Provider
)
}
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment