Unverified Commit ddf80f5e authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation

fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents 4d0483f5 c048ca80
......@@ -5,6 +5,7 @@ package service
import (
"context"
"database/sql"
"sync"
"testing"
"time"
......@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
require.Nil(t, reloadedFirst.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption-concurrent@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-concurrent").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-concurrent",
},
})
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
......@@ -356,3 +458,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "stale-replay-subject",
},
BrowserSessionKey: "browser-session",
})
require.NoError(t, err)
loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
require.NoError(t, err)
consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "legacy-token-subject",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
},
},
})
require.NoError(t, err)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
stored, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
require.True(t, ok)
require.NotContains(t, completion, "access_token")
require.NotContains(t, completion, "refresh_token")
require.NotContains(t, completion, "expires_in")
require.NotContains(t, completion, "token_type")
require.Equal(t, "/dashboard", completion["redirect"])
}
......@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
......@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
......@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
......@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
......@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion {
if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
......@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
......@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion(密码更改后所有Token失效)
if data.TokenVersion != user.TokenVersion {
if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配,撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
......@@ -1467,8 +1470,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
// Access/refresh token verification both depend on TokenVersion, so bumping it provides
// immediate revocation even if refresh-token cache cleanup later fails.
func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
user.TokenVersion++
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
}
return nil
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func resolvedTokenVersion(user *User) int64 {
if user == nil {
return 0
}
if user.TokenVersionResolved {
return user.TokenVersion
}
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
sum := sha256.Sum256([]byte(material))
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}
......@@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"time"
......@@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
......@@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
}
func newAuthServiceForEmailBindWithRefreshCache(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
refreshTokenCache service.RefreshTokenCache,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
......@@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
......@@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
require.Equal(t, 0, newIdentityCount)
}
func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
ctx := context.Background()
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
refreshTokenCache := newEmailBindRefreshTokenCacheStub()
userRepo := newEmailBindUserRepoStub(&service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Username: "legacy-user",
PasswordHash: "old-hash",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
})
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
}, "")
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
storedUser, err := userRepo.GetByID(ctx, 41)
require.NoError(t, err)
require.Equal(t, "new@example.com", storedUser.Email)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
_, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
require.ErrorIs(t, err, service.ErrTokenRevoked)
_, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
}
type emailBindSettingRepoStub struct {
values map[string]string
}
......@@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
type emailBindRefreshTokenCacheStub struct {
mu sync.Mutex
tokens map[string]*service.RefreshTokenData
userSets map[int64]map[string]struct{}
families map[string]map[string]struct{}
}
func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
return &emailBindRefreshTokenCacheStub{
tokens: make(map[string]*service.RefreshTokenData),
userSets: make(map[int64]map[string]struct{}),
families: make(map[string]map[string]struct{}),
}
}
func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
cloned := *data
s.tokens[tokenHash] = &cloned
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
s.mu.Lock()
defer s.mu.Unlock()
data, ok := s.tokens[tokenHash]
if !ok {
return nil, service.ErrRefreshTokenNotFound
}
cloned := *data
return &cloned, nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.userSets[userID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
}
delete(s.userSets, userID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.families[familyID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
}
delete(s.families, familyID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.userSets[userID] == nil {
s.userSets[userID] = make(map[string]struct{})
}
s.userSets[userID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.families[familyID] == nil {
s.families[familyID] = make(map[string]struct{})
}
s.families[familyID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.userSets[userID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.families[familyID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.families[familyID][tokenHash]
return ok, nil
}
type emailBindUserRepoStub struct {
mu sync.Mutex
usersByID map[int64]*service.User
usersByEmail map[string]*service.User
}
func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
cloned := cloneEmailBindUser(user)
return &emailBindUserRepoStub{
usersByID: map[int64]*service.User{
cloned.ID: cloned,
},
usersByEmail: map[string]*service.User{
cloned.Email: cloned,
},
}
}
func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByID[id]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByEmail[email]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.usersByID[user.ID]
if !ok {
return service.ErrUserNotFound
}
delete(s.usersByEmail, existing.Email)
cloned := cloneEmailBindUser(user)
s.usersByID[user.ID] = cloned
s.usersByEmail[cloned.Email] = cloned
return nil
}
func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.usersByEmail[email]
return ok, nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
func cloneEmailBindUser(user *service.User) *service.User {
if user == nil {
return nil
}
cloned := *user
return &cloned
}
......@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
......@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 {
return typeInstances
}
......@@ -44,11 +44,25 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 {
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil {
delete(filtered, method)
continue
}
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
if providerKey == "" {
if len(matching) == 0 {
delete(filtered, method)
continue
}
filtered[method] = matching
continue
}
selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
if len(selectedInstances) == 0 {
delete(filtered, method)
continue
}
filtered[method] = selectedInstances
}
return filtered
}
......
......@@ -6,6 +6,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
func TestUnionFloat(t *testing.T) {
......@@ -301,7 +302,109 @@ func TestPcInstanceTypeLimits(t *testing.T) {
})
}
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
tests := []struct {
name string
sourceSetting string
wantAlipaySingleMin float64
wantAlipaySingleMax float64
wantGlobalMin float64
wantGlobalMax float64
}{
{
name: "official source",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantAlipaySingleMin: 10,
wantAlipaySingleMax: 100,
wantGlobalMin: 10,
wantGlobalMax: 300,
},
{
name: "easypay source",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantAlipaySingleMin: 20,
wantAlipaySingleMax: 200,
wantGlobalMin: 20,
wantGlobalMax: 300,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
},
},
}
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
if !ok {
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
}
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
}
})
}
}
func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
......@@ -313,20 +416,18 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetName("EasyPay Mixed").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetSupportedTypes("alipay,wxpay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
......@@ -335,31 +436,26 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
require.NoError(t, err)
svc := &PaymentConfigService{
entClient: client,
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}},
}
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
require.NoError(t, err)
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
}
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
require.True(t, ok, "expected alipay limits to remain visible")
require.Equal(t, 10.0, alipayLimits.SingleMin)
require.Equal(t, 200.0, alipayLimits.SingleMax)
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
}
require.True(t, ok, "expected wxpay limits to remain visible")
require.Equal(t, 30.0, wxpayLimits.SingleMin)
require.Equal(t, 400.0, wxpayLimits.SingleMax)
require.Equal(t, 10.0, resp.GlobalMin)
require.Equal(t, 400.0, resp.GlobalMax)
}
......@@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
// changed while the instance has in-progress orders. This includes secrets plus
// all provider identity fields that are snapshotted into orders or used by
// webhook/refund verification.
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
fields, ok := providerSensitiveConfigFields[providerKey]
if !ok {
......@@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
return found
}
func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool {
fields, ok := providerPendingOrderProtectedConfigFields[providerKey]
if !ok {
return false
}
for fieldName := range fields {
if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) {
return true
}
}
return false
}
func providerConfigFieldValue(config map[string]string, fieldName string) string {
for key, value := range config {
if strings.EqualFold(key, fieldName) {
return value
}
}
return ""
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
......@@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
var pendingOrderCount *int
getPendingOrderCount := func() (int, error) {
if pendingOrderCount != nil {
return *pendingOrderCount, nil
}
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return 0, fmt.Errorf("check pending orders: %w", err)
}
pendingOrderCount = &count
return count, nil
}
nextEnabled := current.Enabled
if req.Enabled != nil {
nextEnabled = *req.Enabled
......@@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
return nil, err
}
var mergedConfig map[string]string
if req.Config != nil {
hasSensitive := false
for k, v := range req.Config {
if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
hasSensitive = true
break
}
currentConfig, err := s.decryptConfig(current.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config: %w", err)
}
if hasSensitive {
count, err := s.countPendingOrders(ctx, id)
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) {
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
......@@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if req.Enabled != nil && !*req.Enabled {
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
......@@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.Enabled != nil {
finalEnabled = *req.Enabled
}
var mergedConfig map[string]string
if req.Config != nil {
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
}
if finalEnabled {
configToValidate := mergedConfig
if configToValidate == nil {
......@@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
// Load current instance to compare types
......
......@@ -4,8 +4,16 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"strconv"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -199,7 +207,7 @@ func TestJoinTypes(t *testing.T) {
}
}
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
t.Parallel()
ctx := context.Background()
......@@ -227,15 +235,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay",
Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"},
Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
SupportedTypes: []string{"alipay"},
Enabled: true,
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
t.Parallel()
ctx := context.Background()
......@@ -264,7 +271,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay",
Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"},
Config: validWxpayProviderConfig(t),
SupportedTypes: []string{"wxpay"},
Enabled: false,
})
......@@ -273,8 +280,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true),
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
......@@ -314,6 +320,289 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
}
func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay appId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"appId": "wx-app-updated"},
fieldName: "appId",
wantValue: "wx-app-test",
},
{
name: "wxpay mpAppId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfigWithJSAPIAppID,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"},
fieldName: "mpAppId",
wantValue: "wx-mp-app-test",
},
{
name: "wxpay mchId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mchId": "mch-updated"},
fieldName: "mchId",
wantValue: "mch-test",
},
{
name: "wxpay publicKeyId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"},
fieldName: "publicKeyId",
wantValue: "public-key-id-test",
},
{
name: "wxpay certSerial",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"certSerial": "cert-serial-updated"},
fieldName: "certSerial",
wantValue: "cert-serial-test",
},
{
name: "alipay appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-updated"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
{
name: "easypay pid",
providerKey: payment.TypeEasyPay,
createConfig: validEasyPayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"pid": "pid-updated"},
fieldName: "pid",
wantValue: "pid-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "protected-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.Nil(t, updated)
require.Error(t, err)
require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err))
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay notifyUrl",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"},
fieldName: "notifyUrl",
wantValue: "https://merchant.example.com/wxpay/notify-v2",
},
{
name: "alipay same appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-test"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "safe-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.NoError(t, err)
require.NotNil(t, updated)
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
t.Helper()
user, err := client.User.Create().
SetEmail("provider-config-pending@example.com").
SetPasswordHash("hash").
SetUsername("provider-config-pending-user").
Save(ctx)
require.NoError(t, err)
instanceID := strconv.FormatInt(instance.ID, 10)
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID).
SetOutTradeNo("sub2_pending_provider_config_" + instanceID).
SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instanceID).
SetProviderKey(instance.ProviderKey).
Save(ctx)
require.NoError(t, err)
}
func providerPendingOrderPaymentType(providerKey string) string {
switch providerKey {
case payment.TypeWxpay:
return payment.TypeWxpay
case payment.TypeAlipay:
return payment.TypeAlipay
default:
return payment.TypeAlipay
}
}
func boolPtrValue(v bool) *bool {
return &v
}
func validAlipayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"appId": "alipay-app-test",
"privateKey": "alipay-private-key-test",
"notifyUrl": "https://merchant.example.com/alipay/notify",
"returnUrl": "https://merchant.example.com/alipay/return",
}
}
func validEasyPayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"pid": "pid-test",
"pkey": "pkey-test",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/easypay/notify",
"returnUrl": "https://merchant.example.com/easypay/return",
}
}
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return map[string]string{
"appId": "wx-app-test",
"mchId": "mch-test",
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": "12345678901234567890123456789012",
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-test",
"certSerial": "cert-serial-test",
}
}
func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string {
t.Helper()
cfg := validWxpayProviderConfig(t)
cfg["mpAppId"] = "wx-mp-app-test"
return cfg
}
......@@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
if !isValidProviderAmount(paid) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
"expected": o.PayAmount,
"paid": paid,
"tradeNo": tradeNo,
})
return fmt.Errorf("invalid paid amount from provider: %v", paid)
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func isValidProviderAmount(amount float64) bool {
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
return validateProviderSnapshotMetadata(order, providerKey, metadata)
}
......
......@@ -5,6 +5,7 @@ package service
import (
"context"
"errors"
"math"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
......@@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
assert.False(t, ok)
}
func TestIsValidProviderAmount(t *testing.T) {
t.Parallel()
assert.True(t, isValidProviderAmount(0.01))
assert.False(t, isValidProviderAmount(0))
assert.False(t, isValidProviderAmount(-1))
assert.False(t, isValidProviderAmount(math.NaN()))
assert.False(t, isValidProviderAmount(math.Inf(1)))
}
func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
t.Parallel()
......
......@@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
if err != nil {
return nil, err
}
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
......@@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
......@@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
const maxAttempts = 5
for attempt := 0; attempt < maxAttempts; attempt++ {
candidate := generateOutTradeNo()
exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
if err != nil {
return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
}
if !exists {
return candidate, nil
}
}
return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
......@@ -360,13 +379,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost)
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
if err != nil {
return nil, err
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
if resume.isSigningConfigured() {
if canonicalReturnURL != "" && resume.isSigningConfigured() {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
......@@ -380,7 +399,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken)
if err != nil {
return nil, err
}
......@@ -482,6 +501,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
if err != nil {
return nil, err
}
if err := s.paymentResume().ensureSigningKey(); err != nil {
return nil, err
}
authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
if err != nil {
......
......@@ -31,3 +31,68 @@ func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *tes
t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
}
}
func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
tests := []struct {
name string
source string
wantOfficial bool
}{
{
name: "official source selected",
source: VisibleMethodSourceOfficialWechat,
wantOfficial: true,
},
{
name: "easypay source selected",
source: VisibleMethodSourceEasyPayWechat,
wantOfficial: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay wxpay instance: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpaySource: tt.source,
},
},
},
}
if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial {
t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial)
}
})
}
}
......@@ -150,6 +150,20 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if !isValidProviderAmount(resp.Amount) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
"expected": o.PayAmount,
"paid": resp.Amount,
"tradeNo": resp.TradeNo,
"queryRef": queryRef,
})
slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef)
if !retryOK {
return ""
}
resp = retriedResp
}
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
......@@ -174,6 +188,21 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) {
if prov == nil || strings.TrimSpace(queryRef) == "" {
return nil, false
}
resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err)
return nil, false
}
if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) {
return nil, false
}
return resp, true
}
func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
if order == nil {
return ""
......@@ -224,6 +253,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
......@@ -251,6 +284,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
......@@ -260,6 +297,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return o, nil
}
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
outTradeNo := strings.TrimSpace(raw)
if outTradeNo == "" {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
}
if len(outTradeNo) > 64 {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
for _, ch := range outTradeNo {
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
......
......@@ -21,6 +21,8 @@ import (
type paymentOrderLifecycleQueryProvider struct {
lastQueryTradeNo string
queryCalls int
responses []*payment.QueryOrderResponse
resp *payment.QueryOrderResponse
}
......@@ -48,6 +50,14 @@ func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, paym
func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
p.lastQueryTradeNo = tradeNo
p.queryCalls++
if len(p.responses) > 0 {
resp := p.responses[0]
if len(p.responses) > 1 {
p.responses = p.responses[1:]
}
return resp, nil
}
return p.resp, nil
}
......@@ -234,6 +244,194 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-retry@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-retry-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-UPSTREAM-RETRY").
SetOutTradeNo("sub2_checkpaid_retry_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
responses: []*payment.QueryOrderResponse{
{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
{
TradeNo: "upstream-trade-retry",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, 2, provider.queryCalls)
require.Equal(t, OrderStatusCompleted, got.Status)
require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo)
}
func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-zero-amount@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-zero-amount-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
SetOutTradeNo("sub2_checkpaid_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusPending, got.Status)
require.Empty(t, got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusPending, reloaded.Status)
require.Empty(t, reloaded.PaymentTradeNo)
require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
require.Empty(t, redeemRepo.useCalls)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"strings"
"testing"
"time"
......@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
......@@ -159,6 +162,83 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
t.Parallel()
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Intentionally missing payment resume signing key.
encryptionKey: nil,
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if resp == nil {
t.Fatal("expected oauth-required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
......@@ -189,7 +269,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
settingRepo: &paymentConfigSettingRepoStub{values: values},
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
}
......@@ -6,6 +6,7 @@ import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
......@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return nil, fmt.Errorf("get order by resume token: %w", err)
}
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
return nil, invalidResumeTokenMatchError()
}
snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
......@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
return nil, invalidResumeTokenMatchError()
}
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
return nil, invalidResumeTokenMatchError()
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
return nil, invalidResumeTokenMatchError()
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order)
......@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil
}
func invalidResumeTokenMatchError() error {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
}
......@@ -8,6 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
......@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
}
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
......@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
svc := &PaymentService{
entClient: newPaymentConfigServiceTestClient(t),
}
_, err := svc.VerifyOrderPublic(context.Background(), " ")
require.Error(t, err)
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
}
package service
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
......@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type PaymentResumeService struct {
signingKey []byte
verifyKeys [][]byte
}
type visibleMethodLoadBalancer struct {
......@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
svc := &PaymentResumeService{}
if len(signingKey) > 0 {
svc.signingKey = append([]byte(nil), signingKey...)
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
}
for _, fallback := range verifyFallbacks {
if len(fallback) == 0 {
continue
}
cloned := append([]byte(nil), fallback...)
duplicate := false
for _, existing := range svc.verifyKeys {
if bytes.Equal(existing, cloned) {
duplicate = true
break
}
}
if !duplicate {
svc.verifyKeys = append(svc.verifyKeys, cloned)
}
}
return svc
}
func (s *PaymentResumeService) isSigningConfigured() bool {
......@@ -209,7 +232,7 @@ func visibleMethodSourceSettingKey(method string) string {
}
}
func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
......@@ -228,13 +251,29 @@ func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
if parsed.Path != paymentResultReturnPath {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
}
if !sameOriginHost(parsed.Host, srcHost) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin")
}
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool {
if sameOriginHost(returnURLHost, requestHost) {
return true
}
refererURL = strings.TrimSpace(refererURL)
if refererURL == "" {
return false
}
parsedReferer, err := url.Parse(refererURL)
if err != nil || parsedReferer.Host == "" {
return false
}
return sameOriginHost(returnURLHost, parsedReferer.Host)
}
func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) {
canonical := strings.TrimSpace(base)
if canonical == "" {
return "", nil
......@@ -253,6 +292,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
if orderID > 0 {
query.Set("order_id", strconv.FormatInt(orderID, 10))
}
if strings.TrimSpace(outTradeNo) != "" {
query.Set("out_trade_no", strings.TrimSpace(outTradeNo))
}
if strings.TrimSpace(resumeToken) != "" {
query.Set("resume_token", strings.TrimSpace(resumeToken))
}
......@@ -391,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
if !s.verifySignature(parts[0], parts[1]) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
......@@ -401,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest)
}
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
if s == nil {
return false
}
for _, key := range s.verifyKeys {
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
return true
}
}
return false
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 {
return nil
......@@ -412,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
}
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
return signPaymentResumePayload(payload, s.signingKey)
}
func signPaymentResumePayload(payload string, key []byte) string {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
......@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func TestNormalizeVisibleMethods(t *testing.T) {
......@@ -64,7 +65,7 @@ func TestNormalizePaymentSource(t *testing.T) {
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
......@@ -76,7 +77,7 @@ func TestCanonicalizeReturnURL(t *testing.T) {
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
......@@ -84,15 +85,31 @@ func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject external hosts")
}
}
func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL(
"https://app.example.com/payment/result?from=checkout",
"api.example.com",
"https://app.example.com/purchase",
)
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://app.example.com/payment/result?from=checkout" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout")
}
}
func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
}
}
......@@ -100,7 +117,7 @@ func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token")
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
......@@ -119,6 +136,9 @@ func TestBuildPaymentReturnURL(t *testing.T) {
if query.Get("order_id") != strconv.FormatInt(42, 10) {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("out_trade_no") != "sub2_42" {
t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
}
if query.Get("resume_token") != "resume-token" {
t.Fatalf("resume_token = %q", query.Get("resume_token"))
}
......@@ -127,10 +147,34 @@ func TestBuildPaymentReturnURL(t *testing.T) {
}
}
func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
parsed, err := url.Parse(got)
if err != nil {
t.Fatalf("url.Parse returned error: %v", err)
}
query := parsed.Query()
if query.Get("order_id") != "42" {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("out_trade_no") != "sub2_42" {
t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
}
if query.Get("resume_token") != "" {
t.Fatalf("resume_token = %q, want empty", query.Get("resume_token"))
}
}
func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("", 42, "resume-token")
got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
......@@ -290,6 +334,98 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: legacyKey,
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
}
}
func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
svc := newLegacyAwarePaymentResumeService(legacyKey)
explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if explicitClaims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key")
}
legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if legacyClaims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
......@@ -376,6 +512,258 @@ func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method payment.PaymentType
officialName string
officialTypes string
easyPayName string
easyPayTypes string
sourceSetting string
wantProvider string
}{
{
name: "alipay uses official source",
method: payment.TypeAlipay,
officialName: "Official Alipay",
officialTypes: "alipay",
easyPayName: "EasyPay Alipay",
easyPayTypes: "alipay",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantProvider: payment.TypeAlipay,
},
{
name: "alipay uses easypay source",
method: payment.TypeAlipay,
officialName: "Official Alipay",
officialTypes: "alipay",
easyPayName: "EasyPay Alipay",
easyPayTypes: "alipay",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantProvider: payment.TypeEasyPay,
},
{
name: "wxpay uses official source",
method: payment.TypeWxpay,
officialName: "Official WeChat",
officialTypes: "wxpay",
easyPayName: "EasyPay WeChat",
easyPayTypes: "wxpay",
sourceSetting: VisibleMethodSourceOfficialWechat,
wantProvider: payment.TypeWxpay,
},
{
name: "wxpay uses easypay source",
method: payment.TypeWxpay,
officialName: "Official WeChat",
officialTypes: "wxpay",
easyPayName: "EasyPay WeChat",
easyPayTypes: "wxpay",
sourceSetting: VisibleMethodSourceEasyPayWechat,
wantProvider: payment.TypeEasyPay,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
officialProviderKey := payment.TypeAlipay
if tt.method == payment.TypeWxpay {
officialProviderKey = payment.TypeWxpay
}
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey).
SetName(tt.officialName).
SetConfig("{}").
SetSupportedTypes(tt.officialTypes).
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName(tt.easyPayName).
SetConfig("{}").
SetSupportedTypes(tt.easyPayTypes).
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(tt.method): tt.sourceSetting,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != tt.wantProvider {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider)
}
})
}
}
func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(payment.TypeAlipay): "",
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != "" {
t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey)
}
if inner.lastPaymentType != payment.TypeAlipay {
t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method payment.PaymentType
sourceValue string
wantMessage string
}{
{
name: "invalid wxpay source",
method: payment.TypeWxpay,
sourceValue: "stripe",
wantMessage: "wxpay source must be one of the supported payment providers",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
officialProviderKey := payment.TypeAlipay
officialSupportedTypes := "alipay"
officialName := "Official Alipay"
easyPaySupportedTypes := "alipay"
easyPayName := "EasyPay Alipay"
if tt.method == payment.TypeWxpay {
officialProviderKey = payment.TypeWxpay
officialSupportedTypes = "wxpay"
officialName = "Official WeChat"
easyPaySupportedTypes = "wxpay"
easyPayName = "EasyPay WeChat"
}
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey).
SetName(officialName).
SetConfig("{}").
SetSupportedTypes(officialSupportedTypes).
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName(easyPayName).
SetConfig("{}").
SetSupportedTypes(easyPaySupportedTypes).
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(tt.method): tt.sourceValue,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9)
if err == nil {
t.Fatal("SelectInstance should reject invalid visible method source configuration")
}
if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" {
t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE")
}
if infraerrors.Message(err) != tt.wantMessage {
t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage)
}
})
}
}
func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
t.Parallel()
......
package service
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"log/slog"
"math/rand/v2"
"os"
"strings"
"sync"
"time"
......@@ -44,6 +48,8 @@ const (
orderIDPrefix = "sub2_"
)
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
......@@ -179,7 +185,7 @@ type PaymentService struct {
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}
......@@ -259,16 +265,56 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
return psNewPaymentResumeService(s.configService)
}
func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
return newLegacyAwarePaymentResumeService(legacyKey)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
}
return legacyKey, nil
}
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
return signingKey, nil
}
return signingKey, [][]byte{legacyKey}
}
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if len(raw) >= 64 && len(raw)%2 == 0 {
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
return decoded
}
}
return []byte(raw)
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"strings"
......@@ -82,19 +83,52 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return filtered
}
func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error {
metadata := map[string]string{
"payment_method": NormalizeVisibleMethod(method),
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if !providerSupportsVisibleMethod(inst, method) {
continue
}
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
continue
}
filtered = append(filtered, inst)
}
return filtered
}
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
seen := make(map[string]struct{}, len(instances))
keys := make([]string, 0, len(instances))
for _, inst := range instances {
if inst == nil {
continue
}
key := strings.TrimSpace(inst.ProviderKey)
if key == "" {
continue
}
normalized := strings.ToLower(key)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, key)
}
if conflicting != nil {
metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID)
metadata["conflicting_provider_key"] = conflicting.ProviderKey
metadata["conflicting_provider_name"] = conflicting.Name
return keys
}
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return nil
}
for _, inst := range instances {
if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) {
return inst
}
}
return infraerrors.Conflict(
"PAYMENT_PROVIDER_CONFLICT",
fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)),
).WithMetadata(metadata)
return nil
}
func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
......@@ -104,33 +138,72 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes string,
enabled bool,
) error {
if s == nil || s.entClient == nil || !enabled {
return nil
}
// Visible methods are selected by configured source (official/easypay),
// so multiple enabled providers can intentionally claim the same user-facing
// method. Order creation and limits will route through the configured source.
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
return nil
}
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) {
method = NormalizeVisibleMethod(method)
sourceKey := visibleMethodSourceSettingKey(method)
rawSource := ""
if s != nil && s.settingRepo != nil && sourceKey != "" {
value, err := s.settingRepo.GetValue(ctx, sourceKey)
if err != nil {
if !errors.Is(err, ErrSettingNotFound) {
return "", fmt.Errorf("get %s: %w", sourceKey, err)
}
} else {
rawSource = value
}
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
return "", err
}
if normalizedSource == "" {
return "", nil
}
providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource)
if !ok {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source must be one of the supported payment providers", method),
)
}
return providerKey, nil
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
ctx context.Context,
method string,
matching []*dbent.PaymentProviderInstance,
) (string, error) {
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
case 0:
return "", nil
case 1:
return strings.TrimSpace(providerKeys[0]), nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return "", err
}
if providerKey == "" {
return "", nil
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return strings.TrimSpace(selected.ProviderKey), nil
}
return nil
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
......@@ -155,12 +228,15 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
}
matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) {
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
return nil, buildPaymentProviderConflictError(method, matching[0])
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil {
return nil, err
}
if providerKey == "" {
if len(matching) == 0 {
return nil, nil
}
return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil
}
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment