Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
......@@ -51,6 +51,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"`
WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"`
WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"`
WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
WeChatConnectMode string `json:"wechat_connect_mode"`
WeChatConnectScopes string `json:"wechat_connect_scopes"`
WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
......@@ -127,6 +144,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
// OpenAI account scheduling
OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
......@@ -167,6 +193,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
......@@ -189,6 +216,10 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
......
......@@ -7,16 +7,17 @@ import (
)
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
......@@ -34,7 +35,8 @@ type User struct {
type AdminUser struct {
User
Notes string `json:"notes"`
Notes string `json:"notes"`
LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
......
package dto
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
t.Parallel()
lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
lastActiveAt := lastLoginAt.Add(15 * time.Minute)
lastUsedAt := lastLoginAt.Add(45 * time.Minute)
out := UserFromServiceAdmin(&service.User{
ID: 42,
Email: "admin@example.com",
Username: "admin",
Role: service.RoleAdmin,
Status: service.StatusActive,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
})
require.NotNil(t, out)
require.NotNil(t, out.LastActiveAt)
require.NotNil(t, out.LastUsedAt)
require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
}
......@@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "previous_response_id_requires_wsv2"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
return
}
setOpsRequestContext(c, reqModel, reqStream, body)
......@@ -856,7 +861,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
......@@ -866,7 +871,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
......
......@@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: 1},
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
h.Responses(c)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "Responses WebSocket v2")
require.Contains(t, w.Body.String(), "previous_response_id")
}
func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
`{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: 1},
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
h.Responses(c)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "Responses WebSocket v2")
require.NotContains(t, w.Body.String(), "reuse previous_response_id")
}
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
......
package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -202,10 +207,14 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
Amount float64 `json:"amount"`
PaymentType string `json:"payment_type" binding:"required"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
Amount float64 `json:"amount"`
PaymentType string `json:"payment_type" binding:"required"`
OpenID string `json:"openid"`
WechatResumeToken string `json:"wechat_resume_token"`
ReturnURL string `json:"return_url"`
PaymentSource string `json:"payment_source"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
// IsMobile lets the frontend declare its mobile status directly. When
// nil we fall back to User-Agent heuristics (which miss iPadOS / some
// embedded browsers that strip the "Mobile" keyword).
......@@ -225,21 +234,36 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.WechatResumeToken) != "" {
claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
response.ErrorFrom(c, err)
return
}
}
mobile := isMobile(c)
if req.IsMobile != nil {
mobile = *req.IsMobile
}
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
UserID: subject.UserID,
Amount: req.Amount,
PaymentType: req.PaymentType,
ClientIP: c.ClientIP(),
IsMobile: mobile,
SrcHost: c.Request.Host,
SrcURL: c.Request.Referer(),
OrderType: req.OrderType,
PlanID: req.PlanID,
UserID: subject.UserID,
Amount: req.Amount,
PaymentType: req.PaymentType,
OpenID: req.OpenID,
ClientIP: c.ClientIP(),
IsMobile: mobile,
IsWeChatBrowser: isWeChatBrowser(c),
SrcHost: c.Request.Host,
SrcURL: c.Request.Referer(),
ReturnURL: req.ReturnURL,
PaymentSource: req.PaymentSource,
OrderType: req.OrderType,
PlanID: req.PlanID,
})
if err != nil {
response.ErrorFrom(c, err)
......@@ -248,6 +272,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
if req == nil || claims == nil {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
}
openid := strings.TrimSpace(claims.OpenID)
if openid == "" {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
}
paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
if paymentType == "" {
paymentType = payment.TypeWxpay
}
if req.PaymentType != "" {
requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
if requestPaymentType != "" && requestPaymentType != paymentType {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
}
}
req.PaymentType = paymentType
req.OpenID = openid
if strings.TrimSpace(claims.Amount) != "" {
amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
if err != nil || amount <= 0 {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
}
req.Amount = amount
}
if claims.OrderType != "" {
req.OrderType = claims.OrderType
}
if claims.PlanID > 0 {
req.PlanID = claims.PlanID
}
return nil
}
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
......@@ -268,7 +330,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, orders, int64(total), page, pageSize)
response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrder returns a single order for the authenticated user.
......@@ -290,7 +352,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Success(c, order)
response.Success(c, sanitizePaymentOrderForResponse(order))
}
// CancelOrder cancels a pending order for the authenticated user.
......@@ -362,6 +424,10 @@ type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
}
type ResolveOrderByResumeTokenRequest struct {
ResumeToken string `json:"resume_token" binding:"required"`
}
// VerifyOrder actively queries the upstream payment provider to check
// if payment was made, and processes it if so.
// POST /api/v1/payment/orders/verify
......@@ -382,7 +448,7 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Success(c, order)
response.Success(c, sanitizePaymentOrderForResponse(order))
}
// PublicOrderResult is the limited order info returned by the public verify endpoint.
......@@ -397,16 +463,32 @@ type PublicOrderResult struct {
Status string `json:"status"`
}
// VerifyOrderPublic verifies payment status without requiring authentication.
// Returns limited order info (no user details) to prevent information leakage.
var errPaymentPublicOrderVerifyRemoved = infraerrors.New(
http.StatusGone,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED",
"public payment order verification by out_trade_no has been removed; use resume_token recovery instead",
).WithMetadata(map[string]string{
"replacement_endpoint": "/api/v1/payment/public/orders/resolve",
"replacement_field": "resume_token",
})
// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
var req VerifyOrderRequest
response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved)
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
// POST /api/v1/payment/public/orders/resolve
func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
var req ResolveOrderByResumeTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -443,3 +525,27 @@ func isMobile(c *gin.Context) bool {
}
return false
}
func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
if len(orders) == 0 {
return orders
}
out := make([]*dbent.PaymentOrder, 0, len(orders))
for _, order := range orders {
out = append(out, sanitizePaymentOrderForResponse(order))
}
return out
}
func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
if order == nil {
return nil
}
cloned := *order
cloned.ProviderSnapshot = nil
return &cloned
}
func isWeChatBrowser(c *gin.Context) bool {
return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
}
//go:build unit
package handler
import (
"bytes"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
t.Parallel()
req := CreateOrderRequest{
Amount: 0,
PaymentType: payment.TypeWxpay,
OrderType: payment.OrderTypeBalance,
}
err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeSubscription,
PlanID: 7,
})
if err != nil {
t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
}
if req.OpenID != "openid-123" {
t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
}
if req.Amount != 12.5 {
t.Fatalf("amount = %v, want 12.5", req.Amount)
}
if req.OrderType != payment.OrderTypeSubscription {
t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
}
if req.PlanID != 7 {
t.Fatalf("plan_id = %d, want 7", req.PlanID)
}
}
func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
t.Parallel()
req := CreateOrderRequest{
PaymentType: payment.TypeAlipay,
}
err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeBalance,
})
if err == nil {
t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
}
}
func TestVerifyOrderPublicReturnsGone(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusGone, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusGone, resp.Code)
require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason)
require.Contains(t, resp.Message, "removed")
}
package handler
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
......@@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo := extractOutTradeNo(rawBody, providerKey)
provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
if providerKey == payment.TypeWxpay {
c.String(http.StatusBadRequest, "verify failed")
return
}
writeSuccessResponse(c, providerKey)
return
}
......@@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
headers[strings.ToLower(k)] = c.GetHeader(k)
}
notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
if err != nil {
truncatedBody := rawBody
if len(truncatedBody) > webhookLogTruncateLen {
......@@ -103,24 +109,24 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if notification == nil {
writeSuccessResponse(c, providerKey)
writeSuccessResponse(c, resolvedProviderKey)
return
}
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
}
writeSuccessResponse(c, providerKey)
writeSuccessResponse(c, resolvedProviderKey)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
case payment.TypeEasyPay:
case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")
......@@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
return ""
}
func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
var lastErr error
for _, provider := range providers {
if provider == nil {
continue
}
notification, err := provider.VerifyNotification(ctx, rawBody, headers)
if err != nil {
lastErr = err
continue
}
return provider.ProviderKey(), notification, nil
}
if lastErr != nil {
return "", nil, lastErr
}
return "", nil, fmt.Errorf("no webhook provider could verify notification")
}
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`
......
......@@ -3,11 +3,14 @@
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -97,3 +100,104 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
func TestExtractOutTradeNo(t *testing.T) {
tests := []struct {
name string
providerKey string
rawBody string
want string
}{
{
name: "easypay query payload",
providerKey: "easypay",
rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
want: "sub2_123",
},
{
name: "alipay query payload",
providerKey: "alipay",
rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
want: "sub2_456",
},
{
name: "unknown provider",
providerKey: "wxpay",
rawBody: "{}",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
})
}
}
func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
firstErr := errors.New("wrong provider")
providers := []payment.Provider{
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: firstErr,
},
webhookHandlerProviderStub{
key: payment.TypeWxpay,
notification: &payment.PaymentNotification{
OrderID: "sub2_42",
TradeNo: "trade-42",
Status: payment.NotificationStatusSuccess,
},
},
}
providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
require.NoError(t, err)
require.Equal(t, payment.TypeWxpay, providerKey)
require.NotNil(t, notification)
require.Equal(t, "sub2_42", notification.OrderID)
}
func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
providers := []payment.Provider{
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: errors.New("verify failed a"),
},
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: errors.New("verify failed b"),
},
}
_, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
require.Error(t, err)
}
type webhookHandlerProviderStub struct {
key string
notification *payment.PaymentNotification
verifyErr error
}
func (p webhookHandlerProviderStub) Name() string { return p.key }
func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.PaymentType(p.key)}
}
func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
panic("unexpected call")
}
func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
if p.verifyErr != nil {
return nil, p.verifyErr
}
return p.notification, nil
}
func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}
......@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
......@@ -56,6 +57,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
......
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type settingHandlerPublicRepoStub struct {
values map[string]string
}
func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerPublicRepoStub{
values: map[string]string{
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
h.GetPublicSettings(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
}
func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{
values: map[string]string{
service.SettingKeyWeChatConnectEnabled: "true",
service.SettingKeyWeChatConnectAppID: "wx-mp-app",
service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
service.SettingKeyWeChatConnectMode: "mp",
service.SettingKeyWeChatConnectScopes: "snsapi_base",
service.SettingKeyWeChatConnectOpenEnabled: "true",
service.SettingKeyWeChatConnectMPEnabled: "true",
service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
},
}, &config.Config{}), "test-version")
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
h.GetPublicSettings(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.WeChatOAuthEnabled)
require.True(t, resp.Data.WeChatOAuthOpenEnabled)
require.True(t, resp.Data.WeChatOAuthMPEnabled)
}
package handler
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -12,14 +15,21 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
func NewUserHandler(
userService *service.UserService,
authService *service.AuthService,
emailService *service.EmailService,
emailCache service.EmailCache,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
}
......@@ -34,10 +44,33 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
type userProfileResponse struct {
dto.User
AvatarURL string `json:"avatar_url,omitempty"`
AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
Identities service.UserIdentitySummarySet `json:"identities"`
AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
EmailBound bool `json:"email_bound"`
LinuxDoBound bool `json:"linuxdo_bound"`
OIDCBound bool `json:"oidc_bound"`
WeChatBound bool `json:"wechat_bound"`
}
type userProfileSourceContext struct {
Provider string `json:"provider,omitempty"`
Source string `json:"source,omitempty"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
......@@ -47,13 +80,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(userData))
response.Success(c, profileResp)
}
// ChangePassword handles changing user password
......@@ -101,6 +140,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
......@@ -110,7 +150,149 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
type StartIdentityBindingRequest struct {
Provider string `json:"provider" binding:"required"`
RedirectTo string `json:"redirect_to"`
}
type BindEmailIdentityRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code" binding:"required"`
Password string `json:"password" binding:"required"`
}
type SendEmailBindingCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
// POST /api/v1/user/auth-identities/bind/start
func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req StartIdentityBindingRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{
Provider: req.Provider,
RedirectTo: req.RedirectTo,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// BindEmailIdentity verifies and binds a local email identity for the current user.
// POST /api/v1/user/account-bindings/email
func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.authService == nil {
response.InternalError(c, "Auth service not configured")
return
}
var req BindEmailIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
updatedUser, err := h.authService.BindEmailIdentity(
c.Request.Context(),
subject.UserID,
req.Email,
req.VerifyCode,
req.Password,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
// UnbindIdentity removes a third-party sign-in provider from the current user.
// DELETE /api/v1/user/account-bindings/:provider
func (h *UserHandler) UnbindIdentity(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
updatedUser, err := h.userService.UnbindUserAuthProvider(
c.Request.Context(),
subject.UserID,
c.Param("provider"),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
// SendEmailBindingCode sends a verification code for the current user's email binding flow.
// POST /api/v1/user/account-bindings/email/send-code
func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.authService == nil {
response.InternalError(c, "Auth service not configured")
return
}
var req SendEmailBindingCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
......@@ -176,7 +358,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
......@@ -212,7 +400,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
......@@ -248,5 +442,116 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) {
identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user)
if err != nil {
return userProfileResponse{}, err
}
return userProfileResponseFromService(user, identities), nil
}
func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse {
base := dto.UserFromService(user)
if base == nil {
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
User: *base,
AvatarURL: user.AvatarURL,
AvatarSource: avatarSource,
UsernameSource: usernameSource,
DisplayNameSource: usernameSource,
NicknameSource: usernameSource,
ProfileSources: profileSources,
Identities: identities,
AuthBindings: bindings,
IdentityBindings: bindings,
EmailBound: identities.Email.Bound,
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
}
}
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
return map[string]service.UserIdentitySummary{
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
}
}
func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
map[string]*userProfileSourceContext,
*userProfileSourceContext,
*userProfileSourceContext,
) {
if user == nil {
return nil, nil, nil
}
thirdParty := thirdPartyIdentityProviders(identities)
var avatarSource *userProfileSourceContext
if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
usernameValue := strings.TrimSpace(user.Username)
var usernameSource *userProfileSourceContext
for _, summary := range thirdParty {
if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
usernameSource = buildUserProfileSourceContext(summary.Provider)
break
}
}
if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
profileSources := map[string]*userProfileSourceContext{}
if avatarSource != nil {
profileSources["avatar"] = avatarSource
}
if usernameSource != nil {
profileSources["username"] = usernameSource
profileSources["display_name"] = usernameSource
profileSources["nickname"] = usernameSource
}
if len(profileSources) == 0 {
return nil, avatarSource, usernameSource
}
return profileSources, avatarSource, usernameSource
}
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
if summary.Bound {
out = append(out, summary)
}
}
return out
}
func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
provider = strings.TrimSpace(provider)
if provider == "" {
return nil
}
return &userProfileSourceContext{
Provider: provider,
Source: provider,
}
}
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userHandlerRepoStub struct {
user *service.User
identities []service.UserAuthIdentityRecord
unbound []string
}
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
cloned := *user
s.user = &cloned
return nil
}
func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
if s.user == nil || s.user.AvatarURL == "" {
return nil, nil
}
return &service.UserAvatar{
StorageProvider: s.user.AvatarSource,
URL: s.user.AvatarURL,
ContentType: s.user.AvatarMIME,
ByteSize: s.user.AvatarByteSize,
SHA256: s.user.AvatarSHA256,
}, nil
}
func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
s.user.AvatarURL = input.URL
s.user.AvatarSource = input.StorageProvider
s.user.AvatarMIME = input.ContentType
s.user.AvatarByteSize = input.ByteSize
s.user.AvatarSHA256 = input.SHA256
return &service.UserAvatar{
StorageProvider: input.StorageProvider,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
s.user.AvatarURL = ""
s.user.AvatarSource = ""
s.user.AvatarMIME = ""
s.user.AvatarByteSize = 0
s.user.AvatarSHA256 = ""
return nil
}
func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error {
if s.user != nil {
s.user.LastActiveAt = &activeAt
}
return nil
}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
out := make([]service.UserAuthIdentityRecord, len(s.identities))
copy(out, s.identities)
return out, nil
}
func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
s.unbound = append(s.unbound, provider)
filtered := s.identities[:0]
for _, identity := range s.identities {
if identity.ProviderType == provider {
continue
}
filtered = append(filtered, identity)
}
s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...)
return nil
}
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "handler-avatar@example.com",
Username: "handler-avatar",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.UpdateProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
AvatarURL string `json:"avatar_url"`
Username string `json:"username"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
require.Equal(t, "handler-avatar", resp.Data.Username)
}
func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
gin.SetMode(gin.TestMode)
verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-123456",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
},
},
{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "oidc-user-abc",
Metadata: map[string]any{
"suggested_display_name": "OIDC Display",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.GetProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Identities struct {
Email struct {
Bound bool `json:"bound"`
BoundCount int `json:"bound_count"`
DisplayName string `json:"display_name"`
} `json:"email"`
LinuxDo struct {
Bound bool `json:"bound"`
BoundCount int `json:"bound_count"`
DisplayName string `json:"display_name"`
ProviderKey string `json:"provider_key"`
} `json:"linuxdo"`
OIDC struct {
Bound bool `json:"bound"`
DisplayName string `json:"display_name"`
ProviderKey string `json:"provider_key"`
} `json:"oidc"`
WeChat struct {
Bound bool `json:"bound"`
CanBind bool `json:"can_bind"`
BindStartPath string `json:"bind_start_path"`
} `json:"wechat"`
} `json:"identities"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.Identities.Email.Bound)
require.Equal(t, 1, resp.Data.Identities.Email.BoundCount)
require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName)
require.True(t, resp.Data.Identities.LinuxDo.Bound)
require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount)
require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName)
require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey)
require.True(t, resp.Data.Identities.OIDC.Bound)
require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName)
require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
require.False(t, resp.Data.Identities.WeChat.Bound)
require.True(t, resp.Data.Identities.WeChat.CanBind)
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
}
func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
gin.SetMode(gin.TestMode)
verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 21,
Email: "legacy-profile@example.com",
Username: "linuxdo-handle",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
handler.GetProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, true, resp.Data["email_bound"])
require.Equal(t, true, resp.Data["linuxdo_bound"])
require.Equal(t, false, resp.Data["oidc_bound"])
require.Equal(t, false, resp.Data["wechat_bound"])
require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", avatarSource["provider"])
require.Equal(t, "linuxdo", avatarSource["source"])
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, linuxdoBinding["bound"])
require.Equal(t, "linuxdo", linuxdoBinding["provider"])
identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
require.True(t, ok)
emailBinding, ok := identityBindings["email"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
profileSources, ok := resp.Data["profile_sources"].(map[string]any)
require.True(t, ok)
usernameSource, ok := profileSources["username"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", usernameSource["provider"])
require.Equal(t, "linuxdo", usernameSource["source"])
}
type userHandlerEmailCacheStub struct {
data *service.VerificationCodeData
}
func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
return s.data, nil
}
func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
Username: "legacy-user",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
emailCache := &userHandlerEmailCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Params = gin.Params{{Key: "provider", Value: "email"}}
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.BindEmailIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Email string `json:"email"`
EmailBound bool `json:"email_bound"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "new@example.com", resp.Data.Email)
require.True(t, resp.Data.EmailBound)
}
func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 21,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
Metadata: map[string]any{
"username": "linuxdo-handle",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, []string{"linuxdo"}, repo.unbound)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, linuxdoBinding["bound"])
}
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 11,
Email: "current@example.com",
Username: "bound-user",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, user.SetPassword("current-password"))
repo := &userHandlerRepoStub{user: user}
emailCache := &userHandlerEmailCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.BindEmailIdentity(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
require.Equal(t, "current password is incorrect", resp.Message)
require.Equal(t, "current@example.com", repo.user.Email)
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.StartIdentityBinding(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Provider string `json:"provider"`
AuthorizeURL string `json:"authorize_url"`
Method string `json:"method"`
UseBrowserRedirect bool `json:"use_browser_redirect"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "wechat", resp.Data.Provider)
require.Equal(t, "GET", resp.Data.Method)
require.True(t, resp.Data.UseBrowserRedirect)
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
}
......@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
counter atomic.Uint64
}
type contextKey string
const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
appID = strings.TrimSpace(appID)
if appID == "" {
return ctx
}
return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
}
func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
return strings.TrimSpace(appID)
}
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
......@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
}
var matched []*dbent.PaymentProviderInstance
expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
for _, inst := range instances {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
......@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
config, cfgErr := lb.decryptConfig(inst.Config)
if cfgErr != nil {
slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
continue
}
if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
continue
}
}
matched = append(matched, inst)
}
}
......@@ -231,6 +262,11 @@ func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType P
if cl, ok := limits[lookupKey]; ok {
return cl
}
if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" {
if cl, ok := limits[aliasKey]; ok {
return cl
}
}
return ChannelLimits{}
}
......@@ -344,14 +380,45 @@ func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
if supportedTypes == "" {
return true
}
normalizedTarget := normalizeVisibleMethodSupportType(target)
for _, t := range strings.Split(supportedTypes, ",") {
if strings.TrimSpace(t) == target {
supported := strings.TrimSpace(t)
if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget {
return true
}
}
return false
}
func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType {
switch strings.TrimSpace(paymentType) {
case TypeAlipay, TypeAlipayDirect:
return TypeAlipay
case TypeWxpay, TypeWxpayDirect:
return TypeWxpay
default:
return strings.TrimSpace(paymentType)
}
}
func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
switch normalizeVisibleMethodSupportType(paymentType) {
case TypeAlipay:
return TypeAlipayDirect
case TypeWxpay:
return TypeWxpayDirect
default:
return ""
}
}
func resolveWxpayJSAPIAppID(config map[string]string) string {
if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
return appID
}
return strings.TrimSpace(config["appId"])
}
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
......
......@@ -68,10 +68,16 @@ func TestInstanceSupportsType(t *testing.T) {
expected: true,
},
{
name: "partial match should not succeed",
name: "legacy alipay direct supports canonical visible method",
supportedTypes: "alipay_direct",
target: "alipay",
expected: false,
expected: true,
},
{
name: "legacy wxpay direct supports canonical visible method",
supportedTypes: "wxpay_direct",
target: "wxpay",
expected: true,
},
{
name: "empty supported types means all supported",
......@@ -92,6 +98,22 @@ func TestInstanceSupportsType(t *testing.T) {
}
}
func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) {
t.Parallel()
inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66}))
got := getInstanceChannelLimits(inst, TypeAlipay)
if got.SingleMax != 66 {
t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got)
}
wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8}))
wxGot := getInstanceChannelLimits(wxInst, TypeWxpay)
if wxGot.SingleMin != 8 {
t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot)
}
}
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
......
......@@ -26,6 +26,15 @@ const (
alipayRefundSuffix = "-refund"
)
var (
alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
return client.TradeWapPay(param)
}
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
return client.TradePagePay(param)
}
)
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
......@@ -79,6 +88,17 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
func (a *Alipay) MerchantIdentityMetadata() map[string]string {
if a == nil {
return nil
}
appID := strings.TrimSpace(a.config["appId"])
if appID == "" {
return nil
}
return map[string]string{"app_id": appID}
}
// CreatePayment creates an Alipay payment using redirect-only flow:
// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
......@@ -115,7 +135,7 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
payURL, err := client.TradeWapPay(param)
payURL, err := alipayTradeWapPay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
......@@ -134,7 +154,7 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
payURL, err := client.TradePagePay(param)
payURL, err := alipayTradePagePay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
......@@ -176,10 +196,11 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
}
return &payment.QueryOrderResponse{
TradeNo: result.TradeNo,
Status: status,
Amount: amount,
PaidAt: result.SendPayDate,
TradeNo: result.TradeNo,
Status: status,
Amount: amount,
PaidAt: result.SendPayDate,
Metadata: a.MerchantIdentityMetadata(),
}, nil
}
......@@ -210,12 +231,21 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
}
metadata := a.MerchantIdentityMetadata()
if appID := strings.TrimSpace(notification.AppId); appID != "" {
if metadata == nil {
metadata = map[string]string{}
}
metadata["app_id"] = appID
}
return &payment.PaymentNotification{
TradeNo: notification.TradeNo,
OrderID: notification.OutTradeNo,
Amount: amount,
Status: status,
RawData: rawBody,
TradeNo: notification.TradeNo,
OrderID: notification.OutTradeNo,
Amount: amount,
Status: status,
RawData: rawBody,
Metadata: metadata,
}, nil
}
......@@ -278,6 +308,7 @@ func isTradeNotExist(err error) bool {
// Ensure interface compliance.
var (
_ payment.Provider = (*Alipay)(nil)
_ payment.CancelableProvider = (*Alipay)(nil)
_ payment.Provider = (*Alipay)(nil)
_ payment.CancelableProvider = (*Alipay)(nil)
_ payment.MerchantIdentityProvider = (*Alipay)(nil)
)
......@@ -4,8 +4,12 @@ package provider
import (
"errors"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/smartwalle/alipay/v3"
)
func TestIsTradeNotExist(t *testing.T) {
......@@ -130,3 +134,96 @@ func TestNewAlipay(t *testing.T) {
})
}
}
func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
origPagePay := alipayTradePagePay
origWapPay := alipayTradeWapPay
t.Cleanup(func() {
alipayTradePagePay = origPagePay
alipayTradeWapPay = origWapPay
})
pagePayCalls := 0
wapPayCalls := 0
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
pagePayCalls++
if param.OutTradeNo != "sub2_100" {
t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
}
if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
t.Fatalf("notify_url = %q", param.NotifyURL)
}
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
}
alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
wapPayCalls++
return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
}
provider := &Alipay{}
resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{
OrderID: "sub2_100",
Amount: "88.00",
Subject: "Balance recharge",
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pagePayCalls != 1 {
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
}
if wapPayCalls != 0 {
t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
}
if resp.PayURL == "" {
t.Fatal("expected pay_url for desktop page pay")
}
}
func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
origWapPay := alipayTradeWapPay
t.Cleanup(func() {
alipayTradeWapPay = origWapPay
})
wapPayCalls := 0
alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
wapPayCalls++
if param.ReturnURL != "https://merchant.example.com/payment/result" {
t.Fatalf("return_url = %q", param.ReturnURL)
}
return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
}
provider := &Alipay{}
resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{
OrderID: "sub2_101",
Amount: "18.00",
Subject: "Balance recharge",
IsMobile: true,
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wapPayCalls != 1 {
t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
}
if resp.PayURL == "" {
t.Fatal("expected pay_url for mobile wap pay")
}
}
func TestAlipayMerchantIdentityMetadata(t *testing.T) {
t.Parallel()
provider := &Alipay{
config: map[string]string{
"appId": "2021001234567890",
},
}
metadata := provider.MerchantIdentityMetadata()
if metadata["app_id"] != "2021001234567890" {
t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
}
}
......@@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
if e == nil {
return nil
}
pid := strings.TrimSpace(e.config["pid"])
if pid == "" {
return nil
}
return map[string]string{"pid": pid}
}
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
......@@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
return &payment.QueryOrderResponse{
TradeNo: tradeNo,
Status: status,
Amount: amount,
Metadata: e.MerchantIdentityMetadata(),
}, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
......@@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
metadata := e.MerchantIdentityMetadata()
if pid := strings.TrimSpace(params["pid"]); pid != "" {
if metadata == nil {
metadata = map[string]string{}
}
metadata["pid"] = pid
}
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
Amount: amount, Status: status, RawData: rawBody,
Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
}, nil
}
......
......@@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
t.Parallel()
provider := &EasyPay{
config: map[string]string{
"pid": "1001",
},
}
metadata := provider.MerchantIdentityMetadata()
if metadata["pid"] != "1001" {
t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
}
}
......@@ -5,8 +5,8 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
......@@ -20,6 +20,7 @@ import (
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
......@@ -27,8 +28,23 @@ import (
// WeChat Pay constants.
const (
wxpayCurrency = "CNY"
wxpayH5Type = "Wap"
wxpayCurrency = "CNY"
wxpayH5Type = "Wap"
wxpayResultPath = "/payment/result"
)
const (
wxpayMetadataAppID = "appid"
wxpayMetadataMerchantID = "mchid"
wxpayMetadataCurrency = "currency"
wxpayMetadataTradeState = "trade_state"
)
// WeChat Pay create-payment modes.
const (
wxpayModeNative = "native"
wxpayModeH5 = "h5"
wxpayModeJSAPI = "jsapi"
)
// WeChat Pay trade states.
......@@ -49,6 +65,18 @@ const (
wxpayErrNoAuth = "NO_AUTH"
)
var (
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
return svc.Prepay(ctx, req)
}
wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
return svc.Prepay(ctx, req)
}
wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
return svc.PrepayWithRequestPayment(ctx, req)
}
)
type Wxpay struct {
instanceID string
config map[string]string
......@@ -96,6 +124,16 @@ func (w *Wxpay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeWxpay}
}
// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a
// given provider config. A dedicated MP AppID takes precedence over the base
// merchant AppID.
func ResolveWxpayJSAPIAppID(config map[string]string) string {
if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
return appID
}
return strings.TrimSpace(config["appId"])
}
func formatPEM(key, keyType string) string {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "-----BEGIN") {
......@@ -153,30 +191,68 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
if err != nil {
return nil, fmt.Errorf("wxpay create payment: %w", err)
}
if req.IsMobile && req.ClientIP != "" {
resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true)
mode, err := resolveWxpayCreateMode(req)
if err != nil {
return nil, err
}
switch mode {
case wxpayModeJSAPI:
return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
case wxpayModeH5:
resp, err := w.prepayH5(ctx, client, req, notifyURL, totalFen)
if err == nil {
return resp, nil
}
if !strings.Contains(err.Error(), wxpayErrNoAuth) {
return nil, err
if strings.Contains(err.Error(), wxpayErrNoAuth) {
return nil, fmt.Errorf("wxpay h5 payments are not authorized for this merchant: %w", err)
}
slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID)
return nil, err
case wxpayModeNative:
return w.prepayNative(ctx, client, req, notifyURL, totalFen)
default:
return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode)
}
return w.createOrder(ctx, client, req, notifyURL, totalFen, false)
}
func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) {
if useH5 {
return w.prepayH5(ctx, c, req, notifyURL, totalFen)
func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := jsapi.JsapiApiService{Client: c}
cur := wxpayCurrency
appID := ResolveWxpayJSAPIAppID(w.config)
prepayReq := jsapi.PrepayRequest{
Appid: core.String(appID),
Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject),
OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur},
Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))},
}
return w.prepayNative(ctx, c, req, notifyURL, totalFen)
if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" {
prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)}
}
resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq)
if err != nil {
return nil, fmt.Errorf("wxpay jsapi prepay: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
ResultType: payment.CreatePaymentResultJSAPIReady,
JSAPI: &payment.WechatJSAPIPayload{
AppID: wxSV(resp.Appid),
TimeStamp: wxSV(resp.TimeStamp),
NonceStr: wxSV(resp.NonceStr),
Package: wxSV(resp.Package),
SignType: wxSV(resp.SignType),
PaySign: wxSV(resp.PaySign),
},
}, nil
}
func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := native.NativeApiService{Client: c}
cur := wxpayCurrency
resp, _, err := svc.Prepay(ctx, native.PrepayRequest{
resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
......@@ -195,13 +271,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := h5.H5ApiService{Client: c}
cur := wxpayCurrency
tp := wxpayH5Type
resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{
resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)},
})
if err != nil {
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
......@@ -210,9 +285,77 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create
if resp.H5Url != nil {
h5URL = *resp.H5Url
}
h5URL, err = appendWxpayRedirectURL(h5URL, req)
if err != nil {
return nil, err
}
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
}
func buildWxpayH5Info(config map[string]string) *h5.H5Info {
tp := wxpayH5Type
info := &h5.H5Info{Type: &tp}
if appName := strings.TrimSpace(config["h5AppName"]); appName != "" {
info.AppName = core.String(appName)
}
if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" {
info.AppUrl = core.String(appURL)
}
return info
}
func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) {
if strings.TrimSpace(req.OpenID) != "" {
return wxpayModeJSAPI, nil
}
if req.IsMobile {
if strings.TrimSpace(req.ClientIP) == "" {
return "", fmt.Errorf("wxpay H5 payment requires client IP")
}
return wxpayModeH5, nil
}
return wxpayModeNative, nil
}
func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) {
h5URL = strings.TrimSpace(h5URL)
returnURL := strings.TrimSpace(req.ReturnURL)
if h5URL == "" || returnURL == "" {
return h5URL, nil
}
redirectURL, err := buildWxpayResultURL(returnURL, req)
if err != nil {
return "", err
}
sep := "&"
if !strings.Contains(h5URL, "?") {
sep = "?"
}
return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil
}
func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) {
u, err := url.Parse(returnURL)
if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
return "", fmt.Errorf("return URL must be an absolute http(s) URL")
}
values := u.Query()
values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
values.Set("payment_type", paymentType)
}
if strings.TrimSpace(u.Path) == "" {
u.Path = wxpayResultPath
}
u.RawPath = ""
u.RawQuery = values.Encode()
u.Fragment = ""
return u.String(), nil
}
func wxSV(s *string) string {
if s == nil {
return ""
......@@ -233,6 +376,32 @@ func mapWxState(s string) string {
}
}
func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
if tx == nil {
return nil
}
metadata := map[string]string{}
if appID := wxSV(tx.Appid); appID != "" {
metadata[wxpayMetadataAppID] = appID
}
if merchantID := wxSV(tx.Mchid); merchantID != "" {
metadata[wxpayMetadataMerchantID] = merchantID
}
if tradeState := wxSV(tx.TradeState); tradeState != "" {
metadata[wxpayMetadataTradeState] = tradeState
}
if tx.Amount != nil {
if currency := wxSV(tx.Amount.Currency); currency != "" {
metadata[wxpayMetadataCurrency] = currency
}
}
if len(metadata) == 0 {
return nil
}
return metadata
}
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
......@@ -257,7 +426,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
return &payment.QueryOrderResponse{
TradeNo: id,
Status: mapWxState(wxSV(tx.TradeState)),
Amount: amt,
PaidAt: pa,
Metadata: buildWxpayTransactionMetadata(tx),
}, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
......@@ -289,7 +464,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
Amount: amt, Status: st, RawData: rawBody,
Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment