Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a04ae28a
Commit
a04ae28a
authored
Apr 13, 2026
by
陈曦
Browse files
merge v0.1.111
parents
68f67198
ad64190b
Changes
302
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
302 of 302+
files are displayed.
Plain diff
Email patch
backend/internal/handler/payment_webhook_handler.go
0 → 100644
View file @
a04ae28a
package
handler
import
(
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PaymentWebhookHandler handles payment provider webhook callbacks.
type
PaymentWebhookHandler
struct
{
paymentService
*
service
.
PaymentService
registry
*
payment
.
Registry
}
// maxWebhookBodySize is the maximum allowed webhook request body size (1 MB).
const
maxWebhookBodySize
=
1
<<
20
// webhookLogTruncateLen is the maximum length of raw body logged on verify failure.
const
webhookLogTruncateLen
=
200
// NewPaymentWebhookHandler creates a new PaymentWebhookHandler.
func
NewPaymentWebhookHandler
(
paymentService
*
service
.
PaymentService
,
registry
*
payment
.
Registry
)
*
PaymentWebhookHandler
{
return
&
PaymentWebhookHandler
{
paymentService
:
paymentService
,
registry
:
registry
,
}
}
// EasyPayNotify handles EasyPay payment notifications.
// POST /api/v1/payment/webhook/easypay
func
(
h
*
PaymentWebhookHandler
)
EasyPayNotify
(
c
*
gin
.
Context
)
{
h
.
handleNotify
(
c
,
payment
.
TypeEasyPay
)
}
// AlipayNotify handles Alipay payment notifications.
// POST /api/v1/payment/webhook/alipay
func
(
h
*
PaymentWebhookHandler
)
AlipayNotify
(
c
*
gin
.
Context
)
{
h
.
handleNotify
(
c
,
payment
.
TypeAlipay
)
}
// WxpayNotify handles WeChat Pay payment notifications.
// POST /api/v1/payment/webhook/wxpay
func
(
h
*
PaymentWebhookHandler
)
WxpayNotify
(
c
*
gin
.
Context
)
{
h
.
handleNotify
(
c
,
payment
.
TypeWxpay
)
}
// StripeWebhook handles Stripe webhook events.
// POST /api/v1/payment/webhook/stripe
func
(
h
*
PaymentWebhookHandler
)
StripeWebhook
(
c
*
gin
.
Context
)
{
h
.
handleNotify
(
c
,
payment
.
TypeStripe
)
}
// handleNotify is the shared logic for all provider webhook handlers.
func
(
h
*
PaymentWebhookHandler
)
handleNotify
(
c
*
gin
.
Context
,
providerKey
string
)
{
var
rawBody
string
if
c
.
Request
.
Method
==
http
.
MethodGet
{
// GET callbacks (e.g. EasyPay) pass params as URL query string
rawBody
=
c
.
Request
.
URL
.
RawQuery
}
else
{
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
c
.
Request
.
Body
,
maxWebhookBodySize
))
if
err
!=
nil
{
slog
.
Error
(
"[Payment Webhook] failed to read body"
,
"provider"
,
providerKey
,
"error"
,
err
)
c
.
String
(
http
.
StatusBadRequest
,
"failed to read body"
)
return
}
rawBody
=
string
(
body
)
}
// Extract out_trade_no to look up the order's specific provider instance.
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo
:=
extractOutTradeNo
(
rawBody
,
providerKey
)
provider
,
err
:=
h
.
paymentService
.
GetWebhookProvider
(
c
.
Request
.
Context
(),
providerKey
,
outTradeNo
)
if
err
!=
nil
{
slog
.
Warn
(
"[Payment Webhook] provider not found"
,
"provider"
,
providerKey
,
"outTradeNo"
,
outTradeNo
,
"error"
,
err
)
writeSuccessResponse
(
c
,
providerKey
)
return
}
headers
:=
make
(
map
[
string
]
string
)
for
k
:=
range
c
.
Request
.
Header
{
headers
[
strings
.
ToLower
(
k
)]
=
c
.
GetHeader
(
k
)
}
notification
,
err
:=
provider
.
VerifyNotification
(
c
.
Request
.
Context
(),
rawBody
,
headers
)
if
err
!=
nil
{
truncatedBody
:=
rawBody
if
len
(
truncatedBody
)
>
webhookLogTruncateLen
{
truncatedBody
=
truncatedBody
[
:
webhookLogTruncateLen
]
+
"...(truncated)"
}
slog
.
Error
(
"[Payment Webhook] verify failed"
,
"provider"
,
providerKey
,
"error"
,
err
,
"method"
,
c
.
Request
.
Method
,
"bodyLen"
,
len
(
rawBody
))
slog
.
Debug
(
"[Payment Webhook] verify failed body"
,
"provider"
,
providerKey
,
"rawBody"
,
truncatedBody
)
c
.
String
(
http
.
StatusBadRequest
,
"verify failed"
)
return
}
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if
notification
==
nil
{
writeSuccessResponse
(
c
,
providerKey
)
return
}
if
err
:=
h
.
paymentService
.
HandlePaymentNotification
(
c
.
Request
.
Context
(),
notification
,
providerKey
);
err
!=
nil
{
slog
.
Error
(
"[Payment Webhook] handle notification failed"
,
"provider"
,
providerKey
,
"error"
,
err
)
c
.
String
(
http
.
StatusInternalServerError
,
"handle failed"
)
return
}
writeSuccessResponse
(
c
,
providerKey
)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
// This allows looking up the correct provider instance before verification.
func
extractOutTradeNo
(
rawBody
,
providerKey
string
)
string
{
switch
providerKey
{
case
payment
.
TypeEasyPay
:
values
,
err
:=
url
.
ParseQuery
(
rawBody
)
if
err
==
nil
{
return
values
.
Get
(
"out_trade_no"
)
}
}
// For other providers (Stripe, Alipay direct, WxPay direct), the registry
// typically has only one instance, so no instance lookup is needed.
return
""
}
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type
wxpaySuccessResponse
struct
{
Code
string
`json:"code"`
Message
string
`json:"message"`
}
// WeChat Pay webhook success response constants.
const
(
wxpaySuccessCode
=
"SUCCESS"
wxpaySuccessMessage
=
"成功"
)
// writeSuccessResponse sends the provider-specific success response.
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
// Stripe expects an empty 200; others accept plain text "success".
func
writeSuccessResponse
(
c
*
gin
.
Context
,
providerKey
string
)
{
switch
providerKey
{
case
payment
.
TypeWxpay
:
c
.
JSON
(
http
.
StatusOK
,
wxpaySuccessResponse
{
Code
:
wxpaySuccessCode
,
Message
:
wxpaySuccessMessage
})
case
payment
.
TypeStripe
:
c
.
String
(
http
.
StatusOK
,
""
)
default
:
c
.
String
(
http
.
StatusOK
,
"success"
)
}
}
backend/internal/handler/payment_webhook_handler_test.go
0 → 100644
View file @
a04ae28a
//go:build unit
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestWriteSuccessResponse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
providerKey
string
wantCode
int
wantContentType
string
wantBody
string
checkJSON
bool
wantJSONCode
string
wantJSONMessage
string
}{
{
name
:
"wxpay returns JSON with code SUCCESS"
,
providerKey
:
"wxpay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"application/json"
,
checkJSON
:
true
,
wantJSONCode
:
"SUCCESS"
,
wantJSONMessage
:
"成功"
,
},
{
name
:
"stripe returns empty 200"
,
providerKey
:
"stripe"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
""
,
},
{
name
:
"easypay returns plain text success"
,
providerKey
:
"easypay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
{
name
:
"alipay returns plain text success"
,
providerKey
:
"alipay"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
{
name
:
"unknown provider returns plain text success"
,
providerKey
:
"unknown_provider"
,
wantCode
:
http
.
StatusOK
,
wantContentType
:
"text/plain"
,
wantBody
:
"success"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
writeSuccessResponse
(
c
,
tt
.
providerKey
)
assert
.
Equal
(
t
,
tt
.
wantCode
,
w
.
Code
)
assert
.
Contains
(
t
,
w
.
Header
()
.
Get
(
"Content-Type"
),
tt
.
wantContentType
)
if
tt
.
checkJSON
{
var
resp
wxpaySuccessResponse
err
:=
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
resp
)
require
.
NoError
(
t
,
err
,
"response body should be valid JSON"
)
assert
.
Equal
(
t
,
tt
.
wantJSONCode
,
resp
.
Code
)
assert
.
Equal
(
t
,
tt
.
wantJSONMessage
,
resp
.
Message
)
}
else
{
assert
.
Equal
(
t
,
tt
.
wantBody
,
w
.
Body
.
String
())
}
})
}
}
func
TestWebhookConstants
(
t
*
testing
.
T
)
{
t
.
Run
(
"maxWebhookBodySize is 1MB"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
int64
(
1
<<
20
),
int64
(
maxWebhookBodySize
))
})
t
.
Run
(
"webhookLogTruncateLen is 200"
,
func
(
t
*
testing
.
T
)
{
assert
.
Equal
(
t
,
200
,
webhookLogTruncateLen
)
})
}
backend/internal/handler/setting_handler.go
View file @
a04ae28a
...
...
@@ -51,10 +51,15 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
PurchaseSubscriptionEnabled
:
settings
.
PurchaseSubscriptionEnabled
,
PurchaseSubscriptionURL
:
settings
.
PurchaseSubscriptionURL
,
TableDefaultPageSize
:
settings
.
TableDefaultPageSize
,
TablePageSizeOptions
:
settings
.
TablePageSizeOptions
,
CustomMenuItems
:
dto
.
ParseUserVisibleMenuItems
(
settings
.
CustomMenuItems
),
CustomEndpoints
:
dto
.
ParseCustomEndpoints
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
OIDCOAuthEnabled
:
settings
.
OIDCOAuthEnabled
,
OIDCOAuthProviderName
:
settings
.
OIDCOAuthProviderName
,
BackendModeEnabled
:
settings
.
BackendModeEnabled
,
PaymentEnabled
:
settings
.
PaymentEnabled
,
Version
:
h
.
version
,
})
}
backend/internal/handler/usage_handler.go
View file @
a04ae28a
...
...
@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime
=
&
t
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
SortBy
:
c
.
DefaultQuery
(
"sort_by"
,
"created_at"
),
SortOrder
:
c
.
DefaultQuery
(
"sort_order"
,
"desc"
),
}
filters
:=
usagestats
.
UsageLogFilters
{
UserID
:
subject
.
UserID
,
// Always filter by current user for security
APIKeyID
:
apiKeyID
,
...
...
backend/internal/handler/usage_handler_request_type_test.go
View file @
a04ae28a
...
...
@@ -16,10 +16,12 @@ import (
type
userUsageRepoCapture
struct
{
service
.
UsageLogRepository
listParams
pagination
.
PaginationParams
listFilters
usagestats
.
UsageLogFilters
}
func
(
s
*
userUsageRepoCapture
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listParams
=
params
s
.
listFilters
=
filters
return
[]
service
.
UsageLog
{},
&
pagination
.
PaginationResult
{
Total
:
0
,
...
...
backend/internal/handler/usage_handler_sort_test.go
0 → 100644
View file @
a04ae28a
package
handler
import
(
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func
TestUserUsageListSortParams
(
t
*
testing
.
T
)
{
repo
:=
&
userUsageRepoCapture
{}
router
:=
newUserUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/usage?sort_by=model&sort_order=ASC"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"model"
,
repo
.
listParams
.
SortBy
)
require
.
Equal
(
t
,
"ASC"
,
repo
.
listParams
.
SortOrder
)
}
func
TestUserUsageListSortDefaults
(
t
*
testing
.
T
)
{
repo
:=
&
userUsageRepoCapture
{}
router
:=
newUserUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/usage"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"created_at"
,
repo
.
listParams
.
SortBy
)
require
.
Equal
(
t
,
"desc"
,
repo
.
listParams
.
SortOrder
)
}
backend/internal/handler/wire.go
View file @
a04ae28a
...
...
@@ -34,6 +34,7 @@ func ProvideAdminHandlers(
apiKeyHandler
*
admin
.
AdminAPIKeyHandler
,
scheduledTestHandler
*
admin
.
ScheduledTestHandler
,
channelHandler
*
admin
.
ChannelHandler
,
paymentHandler
*
admin
.
PaymentHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
Dashboard
:
dashboardHandler
,
...
...
@@ -61,6 +62,7 @@ func ProvideAdminHandlers(
APIKey
:
apiKeyHandler
,
ScheduledTest
:
scheduledTestHandler
,
Channel
:
channelHandler
,
Payment
:
paymentHandler
,
}
}
...
...
@@ -88,6 +90,8 @@ func ProvideHandlers(
openaiGatewayHandler
*
OpenAIGatewayHandler
,
settingHandler
*
SettingHandler
,
totpHandler
*
TotpHandler
,
paymentHandler
*
PaymentHandler
,
paymentWebhookHandler
*
PaymentWebhookHandler
,
_
*
service
.
IdempotencyCoordinator
,
_
*
service
.
IdempotencyCleanupService
,
)
*
Handlers
{
...
...
@@ -104,6 +108,8 @@ func ProvideHandlers(
OpenAIGateway
:
openaiGatewayHandler
,
Setting
:
settingHandler
,
Totp
:
totpHandler
,
Payment
:
paymentHandler
,
PaymentWebhook
:
paymentWebhookHandler
,
}
}
...
...
@@ -121,6 +127,8 @@ var ProviderSet = wire.NewSet(
NewOpenAIGatewayHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
NewPaymentHandler
,
NewPaymentWebhookHandler
,
// Admin handlers
admin
.
NewDashboardHandler
,
...
...
@@ -148,6 +156,7 @@ var ProviderSet = wire.NewSet(
admin
.
NewAdminAPIKeyHandler
,
admin
.
NewScheduledTestHandler
,
admin
.
NewChannelHandler
,
admin
.
NewPaymentHandler
,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers
,
...
...
backend/internal/payment/amount.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"fmt"
"github.com/shopspring/decimal"
)
const
centsPerYuan
=
100
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
// Uses shopspring/decimal for precision.
func
YuanToFen
(
yuanStr
string
)
(
int64
,
error
)
{
d
,
err
:=
decimal
.
NewFromString
(
yuanStr
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"invalid amount: %s"
,
yuanStr
)
}
return
d
.
Mul
(
decimal
.
NewFromInt
(
centsPerYuan
))
.
IntPart
(),
nil
}
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
func
FenToYuan
(
fen
int64
)
float64
{
return
decimal
.
NewFromInt
(
fen
)
.
Div
(
decimal
.
NewFromInt
(
centsPerYuan
))
.
InexactFloat64
()
}
backend/internal/payment/amount_test.go
0 → 100644
View file @
a04ae28a
//go:build unit
package
payment
import
(
"math"
"testing"
)
func
TestYuanToFen
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
want
int64
wantErr
bool
}{
// Normal values
{
name
:
"one yuan"
,
input
:
"1.00"
,
want
:
100
},
{
name
:
"ten yuan fifty fen"
,
input
:
"10.50"
,
want
:
1050
},
{
name
:
"one fen"
,
input
:
"0.01"
,
want
:
1
},
{
name
:
"large amount"
,
input
:
"99999.99"
,
want
:
9999999
},
// Edge: zero
{
name
:
"zero no decimal"
,
input
:
"0"
,
want
:
0
},
{
name
:
"zero with decimal"
,
input
:
"0.00"
,
want
:
0
},
// IEEE 754 precision edge case: 1.15 * 100 = 114.99999... in float64
{
name
:
"ieee754 precision 1.15"
,
input
:
"1.15"
,
want
:
115
},
// More precision edge cases
{
name
:
"ieee754 precision 0.1"
,
input
:
"0.1"
,
want
:
10
},
{
name
:
"ieee754 precision 0.2"
,
input
:
"0.2"
,
want
:
20
},
{
name
:
"ieee754 precision 33.33"
,
input
:
"33.33"
,
want
:
3333
},
// Large value
{
name
:
"hundred thousand"
,
input
:
"100000.00"
,
want
:
10000000
},
// Integer without decimal
{
name
:
"integer 5"
,
input
:
"5"
,
want
:
500
},
{
name
:
"integer 100"
,
input
:
"100"
,
want
:
10000
},
// Single decimal place
{
name
:
"single decimal 1.5"
,
input
:
"1.5"
,
want
:
150
},
// Negative values
{
name
:
"negative one yuan"
,
input
:
"-1.00"
,
want
:
-
100
},
{
name
:
"negative with fen"
,
input
:
"-10.50"
,
want
:
-
1050
},
// Invalid inputs
{
name
:
"empty string"
,
input
:
""
,
wantErr
:
true
},
{
name
:
"alphabetic"
,
input
:
"abc"
,
wantErr
:
true
},
{
name
:
"double dot"
,
input
:
"1.2.3"
,
wantErr
:
true
},
{
name
:
"spaces"
,
input
:
" "
,
wantErr
:
true
},
{
name
:
"special chars"
,
input
:
"$10.00"
,
wantErr
:
true
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
,
err
:=
YuanToFen
(
tt
.
input
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Errorf
(
"YuanToFen(%q) expected error, got %d"
,
tt
.
input
,
got
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"YuanToFen(%q) unexpected error: %v"
,
tt
.
input
,
err
)
}
if
got
!=
tt
.
want
{
t
.
Errorf
(
"YuanToFen(%q) = %d, want %d"
,
tt
.
input
,
got
,
tt
.
want
)
}
})
}
}
func
TestFenToYuan
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
fen
int64
want
float64
}{
{
name
:
"one yuan"
,
fen
:
100
,
want
:
1.0
},
{
name
:
"ten yuan fifty fen"
,
fen
:
1050
,
want
:
10.5
},
{
name
:
"one fen"
,
fen
:
1
,
want
:
0.01
},
{
name
:
"zero"
,
fen
:
0
,
want
:
0.0
},
{
name
:
"large amount"
,
fen
:
9999999
,
want
:
99999.99
},
{
name
:
"negative"
,
fen
:
-
100
,
want
:
-
1.0
},
{
name
:
"negative with fen"
,
fen
:
-
1050
,
want
:
-
10.5
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
FenToYuan
(
tt
.
fen
)
if
math
.
Abs
(
got
-
tt
.
want
)
>
1e-9
{
t
.
Errorf
(
"FenToYuan(%d) = %f, want %f"
,
tt
.
fen
,
got
,
tt
.
want
)
}
})
}
}
func
TestYuanToFenRoundTrip
(
t
*
testing
.
T
)
{
// Verify that converting yuan->fen->yuan preserves the value.
cases
:=
[]
struct
{
yuan
string
fen
int64
}{
{
"0.01"
,
1
},
{
"1.00"
,
100
},
{
"10.50"
,
1050
},
{
"99999.99"
,
9999999
},
}
for
_
,
tc
:=
range
cases
{
fen
,
err
:=
YuanToFen
(
tc
.
yuan
)
if
err
!=
nil
{
t
.
Fatalf
(
"YuanToFen(%q) unexpected error: %v"
,
tc
.
yuan
,
err
)
}
if
fen
!=
tc
.
fen
{
t
.
Errorf
(
"YuanToFen(%q) = %d, want %d"
,
tc
.
yuan
,
fen
,
tc
.
fen
)
}
yuan
:=
FenToYuan
(
fen
)
// Parse expected yuan back for comparison
expectedYuan
:=
FenToYuan
(
tc
.
fen
)
if
math
.
Abs
(
yuan
-
expectedYuan
)
>
1e-9
{
t
.
Errorf
(
"round-trip: FenToYuan(%d) = %f, want %f"
,
fen
,
yuan
,
expectedYuan
)
}
}
}
backend/internal/payment/crypto.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"strings"
)
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
func
Encrypt
(
plaintext
string
,
key
[]
byte
)
(
string
,
error
)
{
if
len
(
key
)
!=
32
{
return
""
,
fmt
.
Errorf
(
"encryption key must be 32 bytes, got %d"
,
len
(
key
))
}
block
,
err
:=
aes
.
NewCipher
(
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create AES cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create GCM: %w"
,
err
)
}
nonce
:=
make
([]
byte
,
gcm
.
NonceSize
())
// 12 bytes for GCM
if
_
,
err
:=
io
.
ReadFull
(
rand
.
Reader
,
nonce
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate nonce: %w"
,
err
)
}
// Seal appends the ciphertext + auth tag
sealed
:=
gcm
.
Seal
(
nil
,
nonce
,
[]
byte
(
plaintext
),
nil
)
// Split sealed into ciphertext and auth tag (last 16 bytes)
tagSize
:=
gcm
.
Overhead
()
ciphertext
:=
sealed
[
:
len
(
sealed
)
-
tagSize
]
authTag
:=
sealed
[
len
(
sealed
)
-
tagSize
:
]
// Format: iv:authTag:ciphertext (all base64)
return
fmt
.
Sprintf
(
"%s:%s:%s"
,
base64
.
StdEncoding
.
EncodeToString
(
nonce
),
base64
.
StdEncoding
.
EncodeToString
(
authTag
),
base64
.
StdEncoding
.
EncodeToString
(
ciphertext
),
),
nil
}
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
func
Decrypt
(
ciphertext
string
,
key
[]
byte
)
(
string
,
error
)
{
if
len
(
key
)
!=
32
{
return
""
,
fmt
.
Errorf
(
"encryption key must be 32 bytes, got %d"
,
len
(
key
))
}
parts
:=
strings
.
SplitN
(
ciphertext
,
":"
,
3
)
if
len
(
parts
)
!=
3
{
return
""
,
fmt
.
Errorf
(
"invalid ciphertext format: expected iv:authTag:ciphertext"
)
}
nonce
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
0
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode IV: %w"
,
err
)
}
authTag
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
1
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode auth tag: %w"
,
err
)
}
encrypted
,
err
:=
base64
.
StdEncoding
.
DecodeString
(
parts
[
2
])
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decode ciphertext: %w"
,
err
)
}
block
,
err
:=
aes
.
NewCipher
(
key
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create AES cipher: %w"
,
err
)
}
gcm
,
err
:=
cipher
.
NewGCM
(
block
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"create GCM: %w"
,
err
)
}
// Reconstruct the sealed data: ciphertext + authTag
sealed
:=
append
(
encrypted
,
authTag
...
)
plaintext
,
err
:=
gcm
.
Open
(
nil
,
nonce
,
sealed
,
nil
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"decrypt: %w"
,
err
)
}
return
string
(
plaintext
),
nil
}
backend/internal/payment/crypto_test.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"crypto/rand"
"strings"
"testing"
)
func
makeKey
(
t
*
testing
.
T
)
[]
byte
{
t
.
Helper
()
key
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
key
);
err
!=
nil
{
t
.
Fatalf
(
"generate random key: %v"
,
err
)
}
return
key
}
func
TestEncryptDecryptRoundTrip
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
plaintexts
:=
[]
string
{
"hello world"
,
"short"
,
"a longer string with special chars: !@#$%^&*()"
,
`{"key":"value","num":42}`
,
"你好世界 unicode test 🎉"
,
strings
.
Repeat
(
"x"
,
10000
),
}
for
_
,
pt
:=
range
plaintexts
{
encrypted
,
err
:=
Encrypt
(
pt
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt(%q) error: %v"
,
pt
[
:
min
(
len
(
pt
),
30
)],
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt error for plaintext %q: %v"
,
pt
[
:
min
(
len
(
pt
),
30
)],
err
)
}
if
decrypted
!=
pt
{
t
.
Fatalf
(
"round-trip failed: got %q, want %q"
,
decrypted
[
:
min
(
len
(
decrypted
),
30
)],
pt
[
:
min
(
len
(
pt
),
30
)])
}
}
}
func
TestEncryptProducesDifferentCiphertexts
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
ct1
,
err
:=
Encrypt
(
"same plaintext"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"first Encrypt error: %v"
,
err
)
}
ct2
,
err
:=
Encrypt
(
"same plaintext"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"second Encrypt error: %v"
,
err
)
}
if
ct1
==
ct2
{
t
.
Fatal
(
"two encryptions of the same plaintext should produce different ciphertexts (random nonce)"
)
}
}
func
TestDecryptWithWrongKeyFails
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key1
:=
makeKey
(
t
)
key2
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
"secret data"
,
key1
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt error: %v"
,
err
)
}
_
,
err
=
Decrypt
(
encrypted
,
key2
)
if
err
==
nil
{
t
.
Fatal
(
"Decrypt with wrong key should fail, but got nil error"
)
}
}
func
TestEncryptRejectsInvalidKeyLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
badKeys
:=
[][]
byte
{
nil
,
make
([]
byte
,
0
),
make
([]
byte
,
16
),
make
([]
byte
,
31
),
make
([]
byte
,
33
),
make
([]
byte
,
64
),
}
for
_
,
key
:=
range
badKeys
{
_
,
err
:=
Encrypt
(
"test"
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Encrypt should reject key of length %d"
,
len
(
key
))
}
}
}
func
TestDecryptRejectsInvalidKeyLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
badKeys
:=
[][]
byte
{
nil
,
make
([]
byte
,
16
),
make
([]
byte
,
33
),
}
for
_
,
key
:=
range
badKeys
{
_
,
err
:=
Decrypt
(
"dummydata:dummydata:dummydata"
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Decrypt should reject key of length %d"
,
len
(
key
))
}
}
}
func
TestEncryptEmptyPlaintext
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
""
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt empty plaintext error: %v"
,
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt empty plaintext error: %v"
,
err
)
}
if
decrypted
!=
""
{
t
.
Fatalf
(
"expected empty string, got %q"
,
decrypted
)
}
}
func
TestEncryptDecryptUnicodeJSON
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
jsonContent
:=
`{"name":"测试用户","email":"test@example.com","balance":100.50}`
encrypted
,
err
:=
Encrypt
(
jsonContent
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt JSON error: %v"
,
err
)
}
decrypted
,
err
:=
Decrypt
(
encrypted
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Decrypt JSON error: %v"
,
err
)
}
if
decrypted
!=
jsonContent
{
t
.
Fatalf
(
"JSON round-trip failed: got %q, want %q"
,
decrypted
,
jsonContent
)
}
}
func
TestDecryptInvalidFormat
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
invalidInputs
:=
[]
string
{
""
,
"nodelimiter"
,
"only:two"
,
"invalid:base64:!!!"
,
}
for
_
,
input
:=
range
invalidInputs
{
_
,
err
:=
Decrypt
(
input
,
key
)
if
err
==
nil
{
t
.
Fatalf
(
"Decrypt(%q) should fail but got nil error"
,
input
)
}
}
}
func
TestCiphertextFormat
(
t
*
testing
.
T
)
{
t
.
Parallel
()
key
:=
makeKey
(
t
)
encrypted
,
err
:=
Encrypt
(
"test"
,
key
)
if
err
!=
nil
{
t
.
Fatalf
(
"Encrypt error: %v"
,
err
)
}
parts
:=
strings
.
SplitN
(
encrypted
,
":"
,
3
)
if
len
(
parts
)
!=
3
{
t
.
Fatalf
(
"ciphertext should have format iv:authTag:ciphertext, got %d parts"
,
len
(
parts
))
}
for
i
,
part
:=
range
parts
{
if
part
==
""
{
t
.
Fatalf
(
"ciphertext part %d is empty"
,
i
)
}
}
}
backend/internal/payment/fee.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"github.com/shopspring/decimal"
)
// CalculatePayAmount computes the total pay amount given a recharge amount and
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
func
CalculatePayAmount
(
rechargeAmount
float64
,
feeRate
float64
)
string
{
amount
:=
decimal
.
NewFromFloat
(
rechargeAmount
)
if
feeRate
<=
0
{
return
amount
.
StringFixed
(
2
)
}
rate
:=
decimal
.
NewFromFloat
(
feeRate
)
fee
:=
amount
.
Mul
(
rate
)
.
Div
(
decimal
.
NewFromInt
(
100
))
.
RoundUp
(
2
)
return
amount
.
Add
(
fee
)
.
StringFixed
(
2
)
}
backend/internal/payment/fee_test.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"testing"
)
func
TestCalculatePayAmount
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
amount
float64
feeRate
float64
expected
string
}{
{
name
:
"zero fee rate returns same amount"
,
amount
:
100.00
,
feeRate
:
0
,
expected
:
"100.00"
,
},
{
name
:
"negative fee rate returns same amount"
,
amount
:
50.00
,
feeRate
:
-
5
,
expected
:
"50.00"
,
},
{
name
:
"1 percent fee rate"
,
amount
:
100.00
,
feeRate
:
1
,
expected
:
"101.00"
,
},
{
name
:
"5 percent fee on 200"
,
amount
:
200.00
,
feeRate
:
5
,
expected
:
"210.00"
,
},
{
name
:
"fee rounds UP to 2 decimal places"
,
amount
:
100.00
,
feeRate
:
3
,
expected
:
"103.00"
,
},
{
name
:
"fee rounds UP small remainder"
,
amount
:
10.00
,
feeRate
:
3.33
,
expected
:
"10.34"
,
// 10 * 3.33 / 100 = 0.333 -> round up -> 0.34
},
{
name
:
"very small amount"
,
amount
:
0.01
,
feeRate
:
1
,
expected
:
"0.02"
,
// 0.01 * 1/100 = 0.0001 -> round up -> 0.01 -> total 0.02
},
{
name
:
"large amount"
,
amount
:
99999.99
,
feeRate
:
10
,
expected
:
"109999.99"
,
// 99999.99 * 10/100 = 9999.999 -> round up -> 10000.00 -> total 109999.99
},
{
name
:
"100 percent fee rate doubles amount"
,
amount
:
50.00
,
feeRate
:
100
,
expected
:
"100.00"
,
},
{
name
:
"precision 0.01 fee difference"
,
amount
:
100.00
,
feeRate
:
1.01
,
expected
:
"101.01"
,
// 100 * 1.01/100 = 1.01
},
{
name
:
"precision 0.02 fee"
,
amount
:
100.00
,
feeRate
:
1.02
,
expected
:
"101.02"
,
},
{
name
:
"zero amount with positive fee"
,
amount
:
0
,
feeRate
:
5
,
expected
:
"0.00"
,
},
{
name
:
"fractional amount no fee"
,
amount
:
19.99
,
feeRate
:
0
,
expected
:
"19.99"
,
},
{
name
:
"fractional fee that causes rounding up"
,
amount
:
33.33
,
feeRate
:
7.77
,
expected
:
"35.92"
,
// 33.33 * 7.77 / 100 = 2.589741 -> round up -> 2.59 -> total 35.92
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
CalculatePayAmount
(
tt
.
amount
,
tt
.
feeRate
)
if
got
!=
tt
.
expected
{
t
.
Fatalf
(
"CalculatePayAmount(%v, %v) = %q, want %q"
,
tt
.
amount
,
tt
.
feeRate
,
got
,
tt
.
expected
)
}
})
}
}
backend/internal/payment/load_balancer.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
)
// Strategy represents a load balancing strategy for provider instance selection.
type
Strategy
string
const
(
StrategyRoundRobin
Strategy
=
"round-robin"
StrategyLeastAmount
Strategy
=
"least-amount"
)
// ChannelLimits holds limits for a single payment channel within a provider instance.
type
ChannelLimits
struct
{
DailyLimit
float64
`json:"dailyLimit,omitempty"`
SingleMin
float64
`json:"singleMin,omitempty"`
SingleMax
float64
`json:"singleMax,omitempty"`
}
// InstanceLimits holds per-channel limits for a provider instance (JSON).
type
InstanceLimits
map
[
string
]
ChannelLimits
// LoadBalancer selects a provider instance for a given payment type.
type
LoadBalancer
interface
{
GetInstanceConfig
(
ctx
context
.
Context
,
instanceID
int64
)
(
map
[
string
]
string
,
error
)
SelectInstance
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
strategy
Strategy
,
orderAmount
float64
)
(
*
InstanceSelection
,
error
)
}
// DefaultLoadBalancer implements LoadBalancer using database queries.
type
DefaultLoadBalancer
struct
{
db
*
dbent
.
Client
encryptionKey
[]
byte
counter
atomic
.
Uint64
}
// NewDefaultLoadBalancer creates a new load balancer.
func
NewDefaultLoadBalancer
(
db
*
dbent
.
Client
,
encryptionKey
[]
byte
)
*
DefaultLoadBalancer
{
return
&
DefaultLoadBalancer
{
db
:
db
,
encryptionKey
:
encryptionKey
}
}
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type
instanceCandidate
struct
{
inst
*
dbent
.
PaymentProviderInstance
dailyUsed
float64
// includes PENDING orders
}
// SelectInstance picks an enabled instance for the given provider key and payment type.
//
// Flow:
// 1. Query all enabled instances for providerKey, filter by supported paymentType
// 2. Batch-query daily usage (PENDING + PAID + COMPLETED + RECHARGING) for all candidates
// 3. Filter out instances where: single-min/max violated OR daily remaining < orderAmount
// 4. Pick from survivors using the configured strategy (round-robin / least-amount)
// 5. If all filtered out, fall back to full list (let the provider itself reject)
func
(
lb
*
DefaultLoadBalancer
)
SelectInstance
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
strategy
Strategy
,
orderAmount
float64
,
)
(
*
InstanceSelection
,
error
)
{
// Step 1: query enabled instances matching payment type.
instances
,
err
:=
lb
.
queryEnabledInstances
(
ctx
,
providerKey
,
paymentType
)
if
err
!=
nil
{
return
nil
,
err
}
// Step 2: batch-fetch daily usage for all candidates.
candidates
:=
lb
.
attachDailyUsage
(
ctx
,
instances
)
// Step 3: filter by limits.
available
:=
filterByLimits
(
candidates
,
paymentType
,
orderAmount
)
if
len
(
available
)
==
0
{
slog
.
Warn
(
"all instances exceeded limits, using full candidate list"
,
"provider"
,
providerKey
,
"payment_type"
,
paymentType
,
"order_amount"
,
orderAmount
,
"count"
,
len
(
candidates
))
available
=
candidates
}
// Step 4: pick by strategy.
selected
:=
lb
.
pickByStrategy
(
available
,
strategy
)
return
lb
.
buildSelection
(
selected
.
inst
)
}
// queryEnabledInstances returns enabled instances for providerKey that support paymentType.
func
(
lb
*
DefaultLoadBalancer
)
queryEnabledInstances
(
ctx
context
.
Context
,
providerKey
string
,
paymentType
PaymentType
,
)
([]
*
dbent
.
PaymentProviderInstance
,
error
)
{
instances
,
err
:=
lb
.
db
.
PaymentProviderInstance
.
Query
()
.
Where
(
paymentproviderinstance
.
ProviderKey
(
providerKey
),
paymentproviderinstance
.
Enabled
(
true
),
)
.
Order
(
dbent
.
Asc
(
paymentproviderinstance
.
FieldSortOrder
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query provider instances: %w"
,
err
)
}
var
matched
[]
*
dbent
.
PaymentProviderInstance
for
_
,
inst
:=
range
instances
{
if
paymentType
==
providerKey
||
InstanceSupportsType
(
inst
.
SupportedTypes
,
paymentType
)
{
matched
=
append
(
matched
,
inst
)
}
}
if
len
(
matched
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"no enabled instance for provider %s type %s"
,
providerKey
,
paymentType
)
}
return
matched
,
nil
}
// attachDailyUsage queries daily usage for each instance in a single pass.
// Usage includes PENDING orders to avoid over-committing capacity.
func
(
lb
*
DefaultLoadBalancer
)
attachDailyUsage
(
ctx
context
.
Context
,
instances
[]
*
dbent
.
PaymentProviderInstance
,
)
[]
instanceCandidate
{
todayStart
:=
startOfDay
(
time
.
Now
())
// Collect instance IDs.
ids
:=
make
([]
string
,
len
(
instances
))
for
i
,
inst
:=
range
instances
{
ids
[
i
]
=
fmt
.
Sprintf
(
"%d"
,
inst
.
ID
)
}
// Batch query: sum pay_amount grouped by provider_instance_id.
type
row
struct
{
InstanceID
string
`json:"provider_instance_id"`
Sum
float64
`json:"sum"`
}
var
rows
[]
row
err
:=
lb
.
db
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
ProviderInstanceIDIn
(
ids
...
),
paymentorder
.
StatusIn
(
OrderStatusPending
,
OrderStatusPaid
,
OrderStatusCompleted
,
OrderStatusRecharging
,
),
paymentorder
.
CreatedAtGTE
(
todayStart
),
)
.
GroupBy
(
paymentorder
.
FieldProviderInstanceID
)
.
Aggregate
(
dbent
.
Sum
(
paymentorder
.
FieldPayAmount
))
.
Scan
(
ctx
,
&
rows
)
if
err
!=
nil
{
slog
.
Warn
(
"batch daily usage query failed, treating all as zero"
,
"error"
,
err
)
}
usageMap
:=
make
(
map
[
string
]
float64
,
len
(
rows
))
for
_
,
r
:=
range
rows
{
usageMap
[
r
.
InstanceID
]
=
r
.
Sum
}
candidates
:=
make
([]
instanceCandidate
,
len
(
instances
))
for
i
,
inst
:=
range
instances
{
candidates
[
i
]
=
instanceCandidate
{
inst
:
inst
,
dailyUsed
:
usageMap
[
fmt
.
Sprintf
(
"%d"
,
inst
.
ID
)],
}
}
return
candidates
}
// filterByLimits removes instances that cannot accommodate the order:
// - orderAmount outside single-transaction [min, max]
// - daily remaining capacity (limit - used) < orderAmount
func
filterByLimits
(
candidates
[]
instanceCandidate
,
paymentType
PaymentType
,
orderAmount
float64
)
[]
instanceCandidate
{
var
result
[]
instanceCandidate
for
_
,
c
:=
range
candidates
{
cl
:=
getInstanceChannelLimits
(
c
.
inst
,
paymentType
)
if
cl
.
SingleMin
>
0
&&
orderAmount
<
cl
.
SingleMin
{
slog
.
Info
(
"order below instance single min, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"order"
,
orderAmount
,
"min"
,
cl
.
SingleMin
)
continue
}
if
cl
.
SingleMax
>
0
&&
orderAmount
>
cl
.
SingleMax
{
slog
.
Info
(
"order above instance single max, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"order"
,
orderAmount
,
"max"
,
cl
.
SingleMax
)
continue
}
if
cl
.
DailyLimit
>
0
&&
c
.
dailyUsed
+
orderAmount
>
cl
.
DailyLimit
{
slog
.
Info
(
"instance daily remaining insufficient, skipping"
,
"instance_id"
,
c
.
inst
.
ID
,
"used"
,
c
.
dailyUsed
,
"order"
,
orderAmount
,
"limit"
,
cl
.
DailyLimit
)
continue
}
result
=
append
(
result
,
c
)
}
return
result
}
// getInstanceChannelLimits returns the channel limits for a specific payment type.
func
getInstanceChannelLimits
(
inst
*
dbent
.
PaymentProviderInstance
,
paymentType
PaymentType
)
ChannelLimits
{
if
inst
.
Limits
==
""
{
return
ChannelLimits
{}
}
var
limits
InstanceLimits
if
err
:=
json
.
Unmarshal
([]
byte
(
inst
.
Limits
),
&
limits
);
err
!=
nil
{
return
ChannelLimits
{}
}
// For Stripe, limits are stored under the provider key "stripe".
lookupKey
:=
paymentType
if
inst
.
ProviderKey
==
"stripe"
{
lookupKey
=
"stripe"
}
if
cl
,
ok
:=
limits
[
lookupKey
];
ok
{
return
cl
}
return
ChannelLimits
{}
}
// pickByStrategy selects one instance from the available candidates.
func
(
lb
*
DefaultLoadBalancer
)
pickByStrategy
(
candidates
[]
instanceCandidate
,
strategy
Strategy
)
instanceCandidate
{
if
strategy
==
StrategyLeastAmount
&&
len
(
candidates
)
>
1
{
return
pickLeastAmount
(
candidates
)
}
// Default: round-robin.
idx
:=
lb
.
counter
.
Add
(
1
)
%
uint64
(
len
(
candidates
))
return
candidates
[
idx
]
}
// pickLeastAmount selects the instance with the lowest daily usage.
// No extra DB queries — usage was pre-fetched in attachDailyUsage.
func
pickLeastAmount
(
candidates
[]
instanceCandidate
)
instanceCandidate
{
best
:=
candidates
[
0
]
for
_
,
c
:=
range
candidates
[
1
:
]
{
if
c
.
dailyUsed
<
best
.
dailyUsed
{
best
=
c
}
}
return
best
}
func
(
lb
*
DefaultLoadBalancer
)
buildSelection
(
selected
*
dbent
.
PaymentProviderInstance
)
(
*
InstanceSelection
,
error
)
{
config
,
err
:=
lb
.
decryptConfig
(
selected
.
Config
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decrypt instance %d config: %w"
,
selected
.
ID
,
err
)
}
if
selected
.
PaymentMode
!=
""
{
config
[
"paymentMode"
]
=
selected
.
PaymentMode
}
return
&
InstanceSelection
{
InstanceID
:
fmt
.
Sprintf
(
"%d"
,
selected
.
ID
),
Config
:
config
,
SupportedTypes
:
selected
.
SupportedTypes
,
PaymentMode
:
selected
.
PaymentMode
,
},
nil
}
func
(
lb
*
DefaultLoadBalancer
)
decryptConfig
(
encrypted
string
)
(
map
[
string
]
string
,
error
)
{
plaintext
,
err
:=
Decrypt
(
encrypted
,
lb
.
encryptionKey
)
if
err
!=
nil
{
return
nil
,
err
}
var
config
map
[
string
]
string
if
err
:=
json
.
Unmarshal
([]
byte
(
plaintext
),
&
config
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal config: %w"
,
err
)
}
return
config
,
nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
func
(
lb
*
DefaultLoadBalancer
)
GetInstanceDailyAmount
(
ctx
context
.
Context
,
instanceID
string
)
(
float64
,
error
)
{
todayStart
:=
startOfDay
(
time
.
Now
())
var
result
[]
struct
{
Sum
float64
`json:"sum"`
}
err
:=
lb
.
db
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
ProviderInstanceID
(
instanceID
),
paymentorder
.
StatusIn
(
OrderStatusCompleted
,
OrderStatusPaid
,
OrderStatusRecharging
),
paymentorder
.
PaidAtGTE
(
todayStart
),
)
.
Aggregate
(
dbent
.
Sum
(
paymentorder
.
FieldPayAmount
))
.
Scan
(
ctx
,
&
result
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"query daily amount: %w"
,
err
)
}
if
len
(
result
)
>
0
{
return
result
[
0
]
.
Sum
,
nil
}
return
0
,
nil
}
func
startOfDay
(
t
time
.
Time
)
time
.
Time
{
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
t
.
Location
())
}
// InstanceSupportsType checks if the given supported types string includes the target type.
// An empty supportedTypes string means all types are supported.
func
InstanceSupportsType
(
supportedTypes
string
,
target
PaymentType
)
bool
{
if
supportedTypes
==
""
{
return
true
}
for
_
,
t
:=
range
strings
.
Split
(
supportedTypes
,
","
)
{
if
strings
.
TrimSpace
(
t
)
==
target
{
return
true
}
}
return
false
}
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func
(
lb
*
DefaultLoadBalancer
)
GetInstanceConfig
(
ctx
context
.
Context
,
instanceID
int64
)
(
map
[
string
]
string
,
error
)
{
inst
,
err
:=
lb
.
db
.
PaymentProviderInstance
.
Get
(
ctx
,
instanceID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get instance %d: %w"
,
instanceID
,
err
)
}
return
lb
.
decryptConfig
(
inst
.
Config
)
}
backend/internal/payment/load_balancer_test.go
0 → 100644
View file @
a04ae28a
//go:build unit
package
payment
import
(
"encoding/json"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
)
func
TestInstanceSupportsType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
supportedTypes
string
target
PaymentType
expected
bool
}{
{
name
:
"exact match single type"
,
supportedTypes
:
"alipay"
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"no match single type"
,
supportedTypes
:
"wxpay"
,
target
:
"alipay"
,
expected
:
false
,
},
{
name
:
"match in comma-separated list"
,
supportedTypes
:
"alipay,wxpay,stripe"
,
target
:
"wxpay"
,
expected
:
true
,
},
{
name
:
"first in comma-separated list"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"last in comma-separated list"
,
supportedTypes
:
"alipay,wxpay,stripe"
,
target
:
"stripe"
,
expected
:
true
,
},
{
name
:
"no match in comma-separated list"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
"stripe"
,
expected
:
false
,
},
{
name
:
"empty target"
,
supportedTypes
:
"alipay,wxpay"
,
target
:
""
,
expected
:
false
,
},
{
name
:
"types with spaces are trimmed"
,
supportedTypes
:
" alipay , wxpay "
,
target
:
"alipay"
,
expected
:
true
,
},
{
name
:
"partial match should not succeed"
,
supportedTypes
:
"alipay_direct"
,
target
:
"alipay"
,
expected
:
false
,
},
{
name
:
"empty supported types means all supported"
,
supportedTypes
:
""
,
target
:
"alipay"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
InstanceSupportsType
(
tt
.
supportedTypes
,
tt
.
target
)
if
got
!=
tt
.
expected
{
t
.
Fatalf
(
"InstanceSupportsType(%q, %q) = %v, want %v"
,
tt
.
supportedTypes
,
tt
.
target
,
got
,
tt
.
expected
)
}
})
}
}
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
func
testInstance
(
id
int64
,
providerKey
,
limits
string
)
*
dbent
.
PaymentProviderInstance
{
return
&
dbent
.
PaymentProviderInstance
{
ID
:
id
,
ProviderKey
:
providerKey
,
Limits
:
limits
,
Enabled
:
true
,
}
}
// makeLimitsJSON builds a limits JSON string for a single payment type.
func
makeLimitsJSON
(
paymentType
string
,
cl
ChannelLimits
)
string
{
m
:=
map
[
string
]
ChannelLimits
{
paymentType
:
cl
}
b
,
_
:=
json
.
Marshal
(
m
)
return
string
(
b
)
}
// ---------------------------------------------------------------------------
// filterByLimits
// ---------------------------------------------------------------------------
func
TestFilterByLimits
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
candidates
[]
instanceCandidate
paymentType
PaymentType
orderAmount
float64
wantIDs
[]
int64
// expected surviving instance IDs
}{
{
name
:
"order below SingleMin is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
5
,
wantIDs
:
nil
,
},
{
name
:
"order at exact SingleMin boundary passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
10
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"order above SingleMax is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
100
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
150
,
wantIDs
:
nil
,
},
{
name
:
"order at exact SingleMax boundary passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
100
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
100
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"daily used + orderAmount exceeding dailyLimit is filtered out"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
480
},
},
paymentType
:
"alipay"
,
orderAmount
:
30
,
wantIDs
:
nil
,
// 480+30=510 > 500
},
{
name
:
"daily used + orderAmount equal to dailyLimit passes (strict greater-than)"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
480
},
},
paymentType
:
"alipay"
,
orderAmount
:
20
,
wantIDs
:
[]
int64
{
1
},
// 480+20=500, 500 > 500 is false → passes
},
{
name
:
"daily used + orderAmount below dailyLimit passes"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
500
})),
dailyUsed
:
400
},
},
paymentType
:
"alipay"
,
orderAmount
:
50
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"no limits configured passes through"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
99999
},
},
paymentType
:
"alipay"
,
orderAmount
:
100
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"multiple candidates with partial filtering"
,
candidates
:
[]
instanceCandidate
{
// singleMax=50, order=80 → filtered out
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMax
:
50
})),
dailyUsed
:
0
},
// no limits → passes
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
0
},
// singleMin=100, order=80 → filtered out
{
inst
:
testInstance
(
3
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
100
})),
dailyUsed
:
0
},
// daily limit ok → passes (500+80=580 < 1000)
{
inst
:
testInstance
(
4
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
80
,
wantIDs
:
[]
int64
{
2
,
4
},
},
{
name
:
"zero SingleMin and SingleMax means no single-transaction limit"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
0
,
SingleMax
:
0
,
DailyLimit
:
0
})),
dailyUsed
:
0
},
},
paymentType
:
"alipay"
,
orderAmount
:
99999
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"all limits combined - order passes all checks"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
200
,
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
50
,
wantIDs
:
[]
int64
{
1
},
},
{
name
:
"all limits combined - order fails SingleMin"
,
candidates
:
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
makeLimitsJSON
(
"alipay"
,
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
200
,
DailyLimit
:
1000
})),
dailyUsed
:
500
},
},
paymentType
:
"alipay"
,
orderAmount
:
5
,
wantIDs
:
nil
,
},
{
name
:
"empty candidates returns empty"
,
candidates
:
nil
,
paymentType
:
"alipay"
,
orderAmount
:
10
,
wantIDs
:
nil
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
filterByLimits
(
tt
.
candidates
,
tt
.
paymentType
,
tt
.
orderAmount
)
gotIDs
:=
make
([]
int64
,
len
(
got
))
for
i
,
c
:=
range
got
{
gotIDs
[
i
]
=
c
.
inst
.
ID
}
if
!
int64SliceEqual
(
gotIDs
,
tt
.
wantIDs
)
{
t
.
Fatalf
(
"filterByLimits() returned IDs %v, want %v"
,
gotIDs
,
tt
.
wantIDs
)
}
})
}
}
// ---------------------------------------------------------------------------
// pickLeastAmount
// ---------------------------------------------------------------------------
func
TestPickLeastAmount
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"picks candidate with lowest dailyUsed"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
300
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
200
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
2
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 2"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"with equal dailyUsed picks the first one"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
100
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
200
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
1
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 1 (first with lowest)"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"single candidate returns that candidate"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
42
,
"easypay"
,
""
),
dailyUsed
:
999
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
42
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 42"
,
got
.
inst
.
ID
)
}
})
t
.
Run
(
"zero usage among non-zero picks zero"
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
candidates
:=
[]
instanceCandidate
{
{
inst
:
testInstance
(
1
,
"easypay"
,
""
),
dailyUsed
:
500
},
{
inst
:
testInstance
(
2
,
"easypay"
,
""
),
dailyUsed
:
0
},
{
inst
:
testInstance
(
3
,
"easypay"
,
""
),
dailyUsed
:
300
},
}
got
:=
pickLeastAmount
(
candidates
)
if
got
.
inst
.
ID
!=
2
{
t
.
Fatalf
(
"pickLeastAmount() picked instance %d, want 2"
,
got
.
inst
.
ID
)
}
})
}
// ---------------------------------------------------------------------------
// getInstanceChannelLimits
// ---------------------------------------------------------------------------
func
TestGetInstanceChannelLimits
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
inst
*
dbent
.
PaymentProviderInstance
paymentType
PaymentType
want
ChannelLimits
}{
{
name
:
"empty limits string returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
""
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{},
},
{
name
:
"invalid JSON returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
"not-json{"
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{},
},
{
name
:
"valid JSON with matching payment type"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5,"singleMax":200,"dailyLimit":1000}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
5
,
SingleMax
:
200
,
DailyLimit
:
1000
},
},
{
name
:
"payment type not in limits returns zero ChannelLimits"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5,"singleMax":200}}`
),
paymentType
:
"wxpay"
,
want
:
ChannelLimits
{},
},
{
name
:
"stripe provider uses stripe lookup key regardless of payment type"
,
inst
:
testInstance
(
1
,
"stripe"
,
`{"stripe":{"singleMin":10,"singleMax":500,"dailyLimit":5000}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
500
,
DailyLimit
:
5000
},
},
{
name
:
"stripe provider ignores payment type key even if present"
,
inst
:
testInstance
(
1
,
"stripe"
,
`{"stripe":{"singleMin":10,"singleMax":500},"alipay":{"singleMin":1,"singleMax":100}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
SingleMin
:
10
,
SingleMax
:
500
},
},
{
name
:
"non-stripe provider uses payment type as lookup key"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"singleMin":5},"wxpay":{"singleMin":10}}`
),
paymentType
:
"wxpay"
,
want
:
ChannelLimits
{
SingleMin
:
10
},
},
{
name
:
"valid JSON with partial limits (only dailyLimit)"
,
inst
:
testInstance
(
1
,
"easypay"
,
`{"alipay":{"dailyLimit":800}}`
),
paymentType
:
"alipay"
,
want
:
ChannelLimits
{
DailyLimit
:
800
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
getInstanceChannelLimits
(
tt
.
inst
,
tt
.
paymentType
)
if
got
!=
tt
.
want
{
t
.
Fatalf
(
"getInstanceChannelLimits() = %+v, want %+v"
,
got
,
tt
.
want
)
}
})
}
}
// ---------------------------------------------------------------------------
// startOfDay
// ---------------------------------------------------------------------------
func
TestStartOfDay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
in
time
.
Time
want
time
.
Time
}{
{
name
:
"midday returns midnight of same day"
,
in
:
time
.
Date
(
2025
,
6
,
15
,
14
,
30
,
45
,
123456789
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
6
,
15
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"midnight returns same time"
,
in
:
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"last second of day returns midnight of same day"
,
in
:
time
.
Date
(
2025
,
12
,
31
,
23
,
59
,
59
,
999999999
,
time
.
UTC
),
want
:
time
.
Date
(
2025
,
12
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
),
},
{
name
:
"preserves timezone location"
,
in
:
time
.
Date
(
2025
,
3
,
10
,
15
,
0
,
0
,
0
,
time
.
FixedZone
(
"CST"
,
8
*
3600
)),
want
:
time
.
Date
(
2025
,
3
,
10
,
0
,
0
,
0
,
0
,
time
.
FixedZone
(
"CST"
,
8
*
3600
)),
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
startOfDay
(
tt
.
in
)
if
!
got
.
Equal
(
tt
.
want
)
{
t
.
Fatalf
(
"startOfDay(%v) = %v, want %v"
,
tt
.
in
,
got
,
tt
.
want
)
}
// Also verify location is preserved.
if
got
.
Location
()
.
String
()
!=
tt
.
want
.
Location
()
.
String
()
{
t
.
Fatalf
(
"startOfDay() location = %v, want %v"
,
got
.
Location
(),
tt
.
want
.
Location
())
}
})
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// int64SliceEqual compares two int64 slices for equality.
// Both nil and empty slices are treated as equal.
func
int64SliceEqual
(
a
,
b
[]
int64
)
bool
{
if
len
(
a
)
==
0
&&
len
(
b
)
==
0
{
return
true
}
if
len
(
a
)
!=
len
(
b
)
{
return
false
}
for
i
:=
range
a
{
if
a
[
i
]
!=
b
[
i
]
{
return
false
}
}
return
true
}
backend/internal/payment/provider/alipay.go
0 → 100644
View file @
a04ae28a
package
provider
import
(
"context"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/smartwalle/alipay/v3"
)
// Alipay product codes.
const
(
alipayProductCodePagePay
=
"FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay
=
"QUICK_WAP_WAY"
)
// Alipay response constants.
const
(
alipayFundChangeYes
=
"Y"
alipayErrTradeNotExist
=
"ACQ.TRADE_NOT_EXIST"
alipayRefundSuffix
=
"-refund"
)
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type
Alipay
struct
{
instanceID
string
config
map
[
string
]
string
// appId, privateKey, publicKey (or alipayPublicKey), notifyUrl, returnUrl
mu
sync
.
Mutex
client
*
alipay
.
Client
}
// NewAlipay creates a new Alipay provider instance.
func
NewAlipay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Alipay
,
error
)
{
required
:=
[]
string
{
"appId"
,
"privateKey"
}
for
_
,
k
:=
range
required
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"alipay config missing required key: %s"
,
k
)
}
}
return
&
Alipay
{
instanceID
:
instanceID
,
config
:
config
,
},
nil
}
func
(
a
*
Alipay
)
getClient
()
(
*
alipay
.
Client
,
error
)
{
a
.
mu
.
Lock
()
defer
a
.
mu
.
Unlock
()
if
a
.
client
!=
nil
{
return
a
.
client
,
nil
}
client
,
err
:=
alipay
.
New
(
a
.
config
[
"appId"
],
a
.
config
[
"privateKey"
],
true
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay init client: %w"
,
err
)
}
pubKey
:=
a
.
config
[
"publicKey"
]
if
pubKey
==
""
{
pubKey
=
a
.
config
[
"alipayPublicKey"
]
}
if
pubKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"alipay config missing required key: publicKey (or alipayPublicKey)"
)
}
if
err
:=
client
.
LoadAliPayPublicKey
(
pubKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay load public key: %w"
,
err
)
}
a
.
client
=
client
return
a
.
client
,
nil
}
func
(
a
*
Alipay
)
Name
()
string
{
return
"Alipay"
}
func
(
a
*
Alipay
)
ProviderKey
()
string
{
return
payment
.
TypeAlipay
}
func
(
a
*
Alipay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeAlipayDirect
}
}
// CreatePayment creates an Alipay payment page URL.
func
(
a
*
Alipay
)
CreatePayment
(
_
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
notifyURL
:=
a
.
config
[
"notifyUrl"
]
if
req
.
NotifyURL
!=
""
{
notifyURL
=
req
.
NotifyURL
}
returnURL
:=
a
.
config
[
"returnUrl"
]
if
req
.
ReturnURL
!=
""
{
returnURL
=
req
.
ReturnURL
}
if
req
.
IsMobile
{
return
a
.
createTrade
(
client
,
req
,
notifyURL
,
returnURL
,
true
)
}
return
a
.
createTrade
(
client
,
req
,
notifyURL
,
returnURL
,
false
)
}
func
(
a
*
Alipay
)
createTrade
(
client
*
alipay
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
,
returnURL
string
,
isMobile
bool
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
if
isMobile
{
param
:=
alipay
.
TradeWapPay
{}
param
.
OutTradeNo
=
req
.
OrderID
param
.
TotalAmount
=
req
.
Amount
param
.
Subject
=
req
.
Subject
param
.
ProductCode
=
alipayProductCodeWapPay
param
.
NotifyURL
=
notifyURL
param
.
ReturnURL
=
returnURL
payURL
,
err
:=
client
.
TradeWapPay
(
param
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradeWapPay: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
payURL
.
String
(),
},
nil
}
param
:=
alipay
.
TradePagePay
{}
param
.
OutTradeNo
=
req
.
OrderID
param
.
TotalAmount
=
req
.
Amount
param
.
Subject
=
req
.
Subject
param
.
ProductCode
=
alipayProductCodePagePay
param
.
NotifyURL
=
notifyURL
param
.
ReturnURL
=
returnURL
payURL
,
err
:=
client
.
TradePagePay
(
param
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradePagePay: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
payURL
.
String
(),
QRCode
:
payURL
.
String
(),
},
nil
}
// QueryOrder queries the trade status via Alipay.
func
(
a
*
Alipay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
client
.
TradeQuery
(
ctx
,
alipay
.
TradeQuery
{
OutTradeNo
:
tradeNo
})
if
err
!=
nil
{
if
isTradeNotExist
(
err
)
{
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
tradeNo
,
Status
:
payment
.
ProviderStatusPending
,
},
nil
}
return
nil
,
fmt
.
Errorf
(
"alipay TradeQuery: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
switch
result
.
TradeStatus
{
case
alipay
.
TradeStatusSuccess
,
alipay
.
TradeStatusFinished
:
status
=
payment
.
ProviderStatusPaid
case
alipay
.
TradeStatusClosed
:
status
=
payment
.
ProviderStatusFailed
}
amount
,
err
:=
strconv
.
ParseFloat
(
result
.
TotalAmount
,
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse amount %q: %w"
,
result
.
TotalAmount
,
err
)
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
result
.
TradeNo
,
Status
:
status
,
Amount
:
amount
,
PaidAt
:
result
.
SendPayDate
,
},
nil
}
// VerifyNotification decodes and verifies an Alipay async notification.
func
(
a
*
Alipay
)
VerifyNotification
(
ctx
context
.
Context
,
rawBody
string
,
_
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
values
,
err
:=
url
.
ParseQuery
(
rawBody
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse notification: %w"
,
err
)
}
notification
,
err
:=
client
.
DecodeNotification
(
ctx
,
values
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay verify notification: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusFailed
if
notification
.
TradeStatus
==
alipay
.
TradeStatusSuccess
||
notification
.
TradeStatus
==
alipay
.
TradeStatusFinished
{
status
=
payment
.
ProviderStatusSuccess
}
amount
,
err
:=
strconv
.
ParseFloat
(
notification
.
TotalAmount
,
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay parse notification amount %q: %w"
,
notification
.
TotalAmount
,
err
)
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
notification
.
TradeNo
,
OrderID
:
notification
.
OutTradeNo
,
Amount
:
amount
,
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
// Refund requests a refund through Alipay.
func
(
a
*
Alipay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
client
.
TradeRefund
(
ctx
,
alipay
.
TradeRefund
{
OutTradeNo
:
req
.
OrderID
,
RefundAmount
:
req
.
Amount
,
RefundReason
:
req
.
Reason
,
OutRequestNo
:
fmt
.
Sprintf
(
"%s-refund-%d"
,
req
.
OrderID
,
time
.
Now
()
.
UnixNano
()),
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"alipay TradeRefund: %w"
,
err
)
}
refundStatus
:=
payment
.
ProviderStatusPending
if
result
.
FundChange
==
alipayFundChangeYes
{
refundStatus
=
payment
.
ProviderStatusSuccess
}
refundID
:=
result
.
TradeNo
if
refundID
==
""
{
refundID
=
req
.
OrderID
+
alipayRefundSuffix
}
return
&
payment
.
RefundResponse
{
RefundID
:
refundID
,
Status
:
refundStatus
,
},
nil
}
// CancelPayment closes a pending trade on Alipay.
func
(
a
*
Alipay
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
client
,
err
:=
a
.
getClient
()
if
err
!=
nil
{
return
err
}
_
,
err
=
client
.
TradeClose
(
ctx
,
alipay
.
TradeClose
{
OutTradeNo
:
tradeNo
})
if
err
!=
nil
{
if
isTradeNotExist
(
err
)
{
return
nil
}
return
fmt
.
Errorf
(
"alipay TradeClose: %w"
,
err
)
}
return
nil
}
func
isTradeNotExist
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
return
strings
.
Contains
(
err
.
Error
(),
alipayErrTradeNotExist
)
}
// Ensure interface compliance.
var
(
_
payment
.
Provider
=
(
*
Alipay
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Alipay
)(
nil
)
)
backend/internal/payment/provider/alipay_test.go
0 → 100644
View file @
a04ae28a
//go:build unit
package
provider
import
(
"errors"
"strings"
"testing"
)
func
TestIsTradeNotExist
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
err
error
want
bool
}{
{
name
:
"nil error returns false"
,
err
:
nil
,
want
:
false
,
},
{
name
:
"error containing ACQ.TRADE_NOT_EXIST returns true"
,
err
:
errors
.
New
(
"alipay: sub_code=ACQ.TRADE_NOT_EXIST, sub_msg=交易不存在"
),
want
:
true
,
},
{
name
:
"error not containing the code returns false"
,
err
:
errors
.
New
(
"alipay: sub_code=ACQ.SYSTEM_ERROR, sub_msg=系统错误"
),
want
:
false
,
},
{
name
:
"error with only partial match returns false"
,
err
:
errors
.
New
(
"ACQ.TRADE_NOT"
),
want
:
false
,
},
{
name
:
"error with exact constant value returns true"
,
err
:
errors
.
New
(
alipayErrTradeNotExist
),
want
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
isTradeNotExist
(
tt
.
err
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"isTradeNotExist(%v) = %v, want %v"
,
tt
.
err
,
got
,
tt
.
want
)
}
})
}
}
func
TestNewAlipay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
validConfig
:=
map
[
string
]
string
{
"appId"
:
"2021001234567890"
,
"privateKey"
:
"MIIEvQIBADANBgkqhkiG9w0BAQEFAASC..."
,
}
// helper to clone and override config fields
withOverride
:=
func
(
overrides
map
[
string
]
string
)
map
[
string
]
string
{
cfg
:=
make
(
map
[
string
]
string
,
len
(
validConfig
))
for
k
,
v
:=
range
validConfig
{
cfg
[
k
]
=
v
}
for
k
,
v
:=
range
overrides
{
cfg
[
k
]
=
v
}
return
cfg
}
tests
:=
[]
struct
{
name
string
config
map
[
string
]
string
wantErr
bool
errSubstr
string
}{
{
name
:
"valid config succeeds"
,
config
:
validConfig
,
wantErr
:
false
,
},
{
name
:
"missing appId"
,
config
:
withOverride
(
map
[
string
]
string
{
"appId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"appId"
,
},
{
name
:
"missing privateKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"privateKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"privateKey"
,
},
{
name
:
"nil config map returns error for appId"
,
config
:
map
[
string
]
string
{},
wantErr
:
true
,
errSubstr
:
"appId"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
NewAlipay
(
"test-instance"
,
tt
.
config
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Fatal
(
"expected error, got nil"
)
}
if
tt
.
errSubstr
!=
""
&&
!
strings
.
Contains
(
err
.
Error
(),
tt
.
errSubstr
)
{
t
.
Errorf
(
"error %q should contain %q"
,
err
.
Error
(),
tt
.
errSubstr
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
got
==
nil
{
t
.
Fatal
(
"expected non-nil Alipay instance"
)
}
if
got
.
instanceID
!=
"test-instance"
{
t
.
Errorf
(
"instanceID = %q, want %q"
,
got
.
instanceID
,
"test-instance"
)
}
})
}
}
backend/internal/payment/provider/easypay.go
0 → 100644
View file @
a04ae28a
// Package provider contains concrete payment provider implementations.
package
provider
import
(
"context"
"crypto/hmac"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// EasyPay constants.
const
(
easypayCodeSuccess
=
1
easypayStatusPaid
=
1
easypayHTTPTimeout
=
10
*
time
.
Second
maxEasypayResponseSize
=
1
<<
20
// 1MB
tradeStatusSuccess
=
"TRADE_SUCCESS"
signTypeMD5
=
"MD5"
paymentModePopup
=
"popup"
deviceMobile
=
"mobile"
)
// EasyPay implements payment.Provider for the EasyPay aggregation platform.
type
EasyPay
struct
{
instanceID
string
config
map
[
string
]
string
httpClient
*
http
.
Client
}
// NewEasyPay creates a new EasyPay provider.
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func
NewEasyPay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
EasyPay
,
error
)
{
for
_
,
k
:=
range
[]
string
{
"pid"
,
"pkey"
,
"apiBase"
,
"notifyUrl"
,
"returnUrl"
}
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"easypay config missing required key: %s"
,
k
)
}
}
return
&
EasyPay
{
instanceID
:
instanceID
,
config
:
config
,
httpClient
:
&
http
.
Client
{
Timeout
:
easypayHTTPTimeout
},
},
nil
}
func
(
e
*
EasyPay
)
Name
()
string
{
return
"EasyPay"
}
func
(
e
*
EasyPay
)
ProviderKey
()
string
{
return
payment
.
TypeEasyPay
}
func
(
e
*
EasyPay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeAlipay
,
payment
.
TypeWxpay
}
}
func
(
e
*
EasyPay
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
mode
:=
e
.
config
[
"paymentMode"
]
if
mode
==
paymentModePopup
{
return
e
.
createRedirectPayment
(
req
)
}
return
e
.
createAPIPayment
(
ctx
,
req
)
}
// createRedirectPayment builds a submit.php URL for browser redirect.
// No server-side API call — the user is redirected to EasyPay's hosted page.
// TradeNo is empty; it arrives via the notify callback after payment.
func
(
e
*
EasyPay
)
createRedirectPayment
(
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
notifyURL
,
returnURL
:=
e
.
resolveURLs
(
req
)
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"type"
:
req
.
PaymentType
,
"out_trade_no"
:
req
.
OrderID
,
"notify_url"
:
notifyURL
,
"return_url"
:
returnURL
,
"name"
:
req
.
Subject
,
"money"
:
req
.
Amount
,
}
if
cid
:=
e
.
resolveCID
(
req
.
PaymentType
);
cid
!=
""
{
params
[
"cid"
]
=
cid
}
if
req
.
IsMobile
{
params
[
"device"
]
=
deviceMobile
}
params
[
"sign"
]
=
easyPaySign
(
params
,
e
.
config
[
"pkey"
])
params
[
"sign_type"
]
=
signTypeMD5
q
:=
url
.
Values
{}
for
k
,
v
:=
range
params
{
q
.
Set
(
k
,
v
)
}
base
:=
strings
.
TrimRight
(
e
.
config
[
"apiBase"
],
"/"
)
payURL
:=
base
+
"/submit.php?"
+
q
.
Encode
()
return
&
payment
.
CreatePaymentResponse
{
PayURL
:
payURL
},
nil
}
// createAPIPayment calls mapi.php to get payurl/qrcode (existing behavior).
func
(
e
*
EasyPay
)
createAPIPayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
notifyURL
,
returnURL
:=
e
.
resolveURLs
(
req
)
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"type"
:
req
.
PaymentType
,
"out_trade_no"
:
req
.
OrderID
,
"notify_url"
:
notifyURL
,
"return_url"
:
returnURL
,
"name"
:
req
.
Subject
,
"money"
:
req
.
Amount
,
"clientip"
:
req
.
ClientIP
,
}
if
cid
:=
e
.
resolveCID
(
req
.
PaymentType
);
cid
!=
""
{
params
[
"cid"
]
=
cid
}
if
req
.
IsMobile
{
params
[
"device"
]
=
deviceMobile
}
params
[
"sign"
]
=
easyPaySign
(
params
,
e
.
config
[
"pkey"
])
params
[
"sign_type"
]
=
signTypeMD5
body
,
err
:=
e
.
post
(
ctx
,
strings
.
TrimRight
(
e
.
config
[
"apiBase"
],
"/"
)
+
"/mapi.php"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay create: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
TradeNo
string
`json:"trade_no"`
PayURL
string
`json:"payurl"`
PayURL2
string
`json:"payurl2"`
// H5 mobile payment URL
QRCode
string
`json:"qrcode"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse: %w"
,
err
)
}
if
resp
.
Code
!=
easypayCodeSuccess
{
return
nil
,
fmt
.
Errorf
(
"easypay error: %s"
,
resp
.
Msg
)
}
payURL
:=
resp
.
PayURL
if
req
.
IsMobile
&&
resp
.
PayURL2
!=
""
{
payURL
=
resp
.
PayURL2
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
resp
.
TradeNo
,
PayURL
:
payURL
,
QRCode
:
resp
.
QRCode
},
nil
}
// resolveURLs returns (notifyURL, returnURL) preferring request values,
// falling back to instance config.
func
(
e
*
EasyPay
)
resolveURLs
(
req
payment
.
CreatePaymentRequest
)
(
string
,
string
)
{
notifyURL
:=
req
.
NotifyURL
if
notifyURL
==
""
{
notifyURL
=
e
.
config
[
"notifyUrl"
]
}
returnURL
:=
req
.
ReturnURL
if
returnURL
==
""
{
returnURL
=
e
.
config
[
"returnUrl"
]
}
return
notifyURL
,
returnURL
}
func
(
e
*
EasyPay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
params
:=
map
[
string
]
string
{
"act"
:
"order"
,
"pid"
:
e
.
config
[
"pid"
],
"key"
:
e
.
config
[
"pkey"
],
"out_trade_no"
:
tradeNo
,
}
body
,
err
:=
e
.
post
(
ctx
,
e
.
config
[
"apiBase"
]
+
"/api.php"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay query: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
Status
int
`json:"status"`
Money
string
`json:"money"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse query: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
if
resp
.
Status
==
easypayStatusPaid
{
status
=
payment
.
ProviderStatusPaid
}
amount
,
_
:=
strconv
.
ParseFloat
(
resp
.
Money
,
64
)
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
tradeNo
,
Status
:
status
,
Amount
:
amount
},
nil
}
func
(
e
*
EasyPay
)
VerifyNotification
(
_
context
.
Context
,
rawBody
string
,
_
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
values
,
err
:=
url
.
ParseQuery
(
rawBody
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse notify: %w"
,
err
)
}
// url.ParseQuery already decodes values — no additional decode needed.
params
:=
make
(
map
[
string
]
string
)
for
k
:=
range
values
{
params
[
k
]
=
values
.
Get
(
k
)
}
sign
:=
params
[
"sign"
]
if
sign
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing sign"
)
}
if
!
easyPayVerifySign
(
params
,
e
.
config
[
"pkey"
],
sign
)
{
return
nil
,
fmt
.
Errorf
(
"invalid signature"
)
}
status
:=
payment
.
ProviderStatusFailed
if
params
[
"trade_status"
]
==
tradeStatusSuccess
{
status
=
payment
.
ProviderStatusSuccess
}
amount
,
_
:=
strconv
.
ParseFloat
(
params
[
"money"
],
64
)
return
&
payment
.
PaymentNotification
{
TradeNo
:
params
[
"trade_no"
],
OrderID
:
params
[
"out_trade_no"
],
Amount
:
amount
,
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
func
(
e
*
EasyPay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
params
:=
map
[
string
]
string
{
"pid"
:
e
.
config
[
"pid"
],
"key"
:
e
.
config
[
"pkey"
],
"trade_no"
:
req
.
TradeNo
,
"out_trade_no"
:
req
.
OrderID
,
"money"
:
req
.
Amount
,
}
body
,
err
:=
e
.
post
(
ctx
,
e
.
config
[
"apiBase"
]
+
"/api.php?act=refund"
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay refund: %w"
,
err
)
}
var
resp
struct
{
Code
int
`json:"code"`
Msg
string
`json:"msg"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"easypay parse refund: %w"
,
err
)
}
if
resp
.
Code
!=
easypayCodeSuccess
{
return
nil
,
fmt
.
Errorf
(
"easypay refund failed: %s"
,
resp
.
Msg
)
}
return
&
payment
.
RefundResponse
{
RefundID
:
req
.
TradeNo
,
Status
:
payment
.
ProviderStatusSuccess
},
nil
}
func
(
e
*
EasyPay
)
resolveCID
(
paymentType
string
)
string
{
if
strings
.
HasPrefix
(
paymentType
,
"alipay"
)
{
if
v
:=
e
.
config
[
"cidAlipay"
];
v
!=
""
{
return
v
}
return
e
.
config
[
"cid"
]
}
if
v
:=
e
.
config
[
"cidWxpay"
];
v
!=
""
{
return
v
}
return
e
.
config
[
"cid"
]
}
func
(
e
*
EasyPay
)
post
(
ctx
context
.
Context
,
endpoint
string
,
params
map
[
string
]
string
)
([]
byte
,
error
)
{
form
:=
url
.
Values
{}
for
k
,
v
:=
range
params
{
form
.
Set
(
k
,
v
)
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
endpoint
,
strings
.
NewReader
(
form
.
Encode
()))
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/x-www-form-urlencoded"
)
resp
,
err
:=
e
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
return
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxEasypayResponseSize
))
}
func
easyPaySign
(
params
map
[
string
]
string
,
pkey
string
)
string
{
keys
:=
make
([]
string
,
0
,
len
(
params
))
for
k
,
v
:=
range
params
{
if
k
==
"sign"
||
k
==
"sign_type"
||
v
==
""
{
continue
}
keys
=
append
(
keys
,
k
)
}
sort
.
Strings
(
keys
)
var
buf
strings
.
Builder
for
i
,
k
:=
range
keys
{
if
i
>
0
{
_
=
buf
.
WriteByte
(
'&'
)
}
_
,
_
=
buf
.
WriteString
(
k
+
"="
+
params
[
k
])
}
_
,
_
=
buf
.
WriteString
(
pkey
)
hash
:=
md5
.
Sum
([]
byte
(
buf
.
String
()))
return
hex
.
EncodeToString
(
hash
[
:
])
}
func
easyPayVerifySign
(
params
map
[
string
]
string
,
pkey
string
,
sign
string
)
bool
{
return
hmac
.
Equal
([]
byte
(
easyPaySign
(
params
,
pkey
)),
[]
byte
(
sign
))
}
backend/internal/payment/provider/easypay_sign_test.go
0 → 100644
View file @
a04ae28a
package
provider
import
(
"testing"
)
func
TestEasyPaySignConsistentOutput
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER123"
,
"name"
:
"Test Product"
,
"money"
:
"10.00"
,
}
pkey
:=
"test_secret_key"
sign1
:=
easyPaySign
(
params
,
pkey
)
sign2
:=
easyPaySign
(
params
,
pkey
)
if
sign1
!=
sign2
{
t
.
Fatalf
(
"easyPaySign should be deterministic: %q != %q"
,
sign1
,
sign2
)
}
if
len
(
sign1
)
!=
32
{
t
.
Fatalf
(
"MD5 hex should be 32 chars, got %d"
,
len
(
sign1
))
}
}
func
TestEasyPaySignExcludesSignAndSignType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"my_key"
base
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
withSign
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"sign"
:
"should_be_ignored"
,
"sign_type"
:
"MD5"
,
}
signBase
:=
easyPaySign
(
base
,
pkey
)
signWithExtra
:=
easyPaySign
(
withSign
,
pkey
)
if
signBase
!=
signWithExtra
{
t
.
Fatalf
(
"sign and sign_type should be excluded: base=%q, withExtra=%q"
,
signBase
,
signWithExtra
)
}
}
func
TestEasyPaySignExcludesEmptyValues
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"key123"
base
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
withEmpty
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"device"
:
""
,
"clientip"
:
""
,
}
signBase
:=
easyPaySign
(
base
,
pkey
)
signWithEmpty
:=
easyPaySign
(
withEmpty
,
pkey
)
if
signBase
!=
signWithEmpty
{
t
.
Fatalf
(
"empty values should be excluded: base=%q, withEmpty=%q"
,
signBase
,
signWithEmpty
)
}
}
func
TestEasyPayVerifySignValid
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER456"
,
"money"
:
"25.00"
,
}
pkey
:=
"secret"
sign
:=
easyPaySign
(
params
,
pkey
)
// Add sign to params (as would come in a real callback)
params
[
"sign"
]
=
sign
params
[
"sign_type"
]
=
"MD5"
if
!
easyPayVerifySign
(
params
,
pkey
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return true for a valid signature"
)
}
}
func
TestEasyPayVerifySignTampered
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
"out_trade_no"
:
"ORDER789"
,
"money"
:
"50.00"
,
}
pkey
:=
"secret"
sign
:=
easyPaySign
(
params
,
pkey
)
// Tamper with the amount
params
[
"money"
]
=
"99.99"
if
easyPayVerifySign
(
params
,
pkey
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return false for tampered params"
)
}
}
func
TestEasyPayVerifySignWrongKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"wxpay"
,
}
sign
:=
easyPaySign
(
params
,
"correct_key"
)
if
easyPayVerifySign
(
params
,
"wrong_key"
,
sign
)
{
t
.
Fatal
(
"easyPayVerifySign should return false with wrong key"
)
}
}
func
TestEasyPaySignEmptyParams
(
t
*
testing
.
T
)
{
t
.
Parallel
()
sign
:=
easyPaySign
(
map
[
string
]
string
{},
"key123"
)
if
sign
==
""
{
t
.
Fatal
(
"easyPaySign with empty params should still produce a hash"
)
}
if
len
(
sign
)
!=
32
{
t
.
Fatalf
(
"MD5 hex should be 32 chars, got %d"
,
len
(
sign
))
}
}
func
TestEasyPaySignSortOrder
(
t
*
testing
.
T
)
{
t
.
Parallel
()
pkey
:=
"test_key"
params1
:=
map
[
string
]
string
{
"a"
:
"1"
,
"b"
:
"2"
,
"c"
:
"3"
,
}
params2
:=
map
[
string
]
string
{
"c"
:
"3"
,
"a"
:
"1"
,
"b"
:
"2"
,
}
sign1
:=
easyPaySign
(
params1
,
pkey
)
sign2
:=
easyPaySign
(
params2
,
pkey
)
if
sign1
!=
sign2
{
t
.
Fatalf
(
"easyPaySign should be order-independent: %q != %q"
,
sign1
,
sign2
)
}
}
func
TestEasyPayVerifySignWrongSignValue
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
map
[
string
]
string
{
"pid"
:
"1001"
,
"type"
:
"alipay"
,
}
pkey
:=
"key"
if
easyPayVerifySign
(
params
,
pkey
,
"00000000000000000000000000000000"
)
{
t
.
Fatal
(
"easyPayVerifySign should return false for an incorrect sign value"
)
}
}
backend/internal/payment/provider/factory.go
0 → 100644
View file @
a04ae28a
package
provider
import
(
"fmt"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// CreateProvider creates a Provider from a provider key, instance ID and decrypted config.
func
CreateProvider
(
providerKey
string
,
instanceID
string
,
config
map
[
string
]
string
)
(
payment
.
Provider
,
error
)
{
switch
providerKey
{
case
payment
.
TypeEasyPay
:
return
NewEasyPay
(
instanceID
,
config
)
case
payment
.
TypeAlipay
:
return
NewAlipay
(
instanceID
,
config
)
case
payment
.
TypeWxpay
:
return
NewWxpay
(
instanceID
,
config
)
case
payment
.
TypeStripe
:
return
NewStripe
(
instanceID
,
config
)
default
:
return
nil
,
fmt
.
Errorf
(
"unknown provider key: %s"
,
providerKey
)
}
}
Prev
1
2
3
4
5
6
7
8
9
10
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment