Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ddf80f5e
Unverified
Commit
ddf80f5e
authored
Apr 22, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 22, 2026
Browse files
Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation
fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents
4d0483f5
c048ca80
Changes
140
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/auth_pending_identity_service_test.go
View file @
ddf80f5e
...
@@ -5,6 +5,7 @@ package service
...
@@ -5,6 +5,7 @@ package service
import
(
import
(
"context"
"context"
"database/sql"
"database/sql"
"sync"
"testing"
"testing"
"time"
"time"
...
@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
...
@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
require
.
Nil
(
t
,
reloadedFirst
.
IdentityID
)
require
.
Nil
(
t
,
reloadedFirst
.
IdentityID
)
}
}
func
TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency
(
t
*
testing
.
T
)
{
svc
,
client
:=
newAuthPendingIdentityServiceTestClient
(
t
)
ctx
:=
context
.
Background
()
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"adoption-concurrent@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
RoleUser
)
.
SetStatus
(
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
identity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"union-concurrent"
)
.
SetMetadata
(
map
[
string
]
any
{})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
svc
.
CreatePendingSession
(
ctx
,
CreatePendingAuthSessionInput
{
Intent
:
"bind_current_user"
,
Identity
:
PendingAuthIdentityKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
ProviderSubject
:
"union-concurrent"
,
},
})
require
.
NoError
(
t
,
err
)
firstCreateStarted
:=
make
(
chan
struct
{})
releaseFirstCreate
:=
make
(
chan
struct
{})
var
firstCreate
sync
.
Once
client
.
IdentityAdoptionDecision
.
Use
(
func
(
next
dbent
.
Mutator
)
dbent
.
Mutator
{
return
dbent
.
MutateFunc
(
func
(
ctx
context
.
Context
,
m
dbent
.
Mutation
)
(
dbent
.
Value
,
error
)
{
blocked
:=
false
if
m
.
Op
()
.
Is
(
dbent
.
OpCreate
)
{
firstCreate
.
Do
(
func
()
{
blocked
=
true
close
(
firstCreateStarted
)
})
}
if
blocked
{
<-
releaseFirstCreate
}
return
next
.
Mutate
(
ctx
,
m
)
})
})
type
adoptionResult
struct
{
decision
*
dbent
.
IdentityAdoptionDecision
err
error
}
input
:=
PendingIdentityAdoptionDecisionInput
{
PendingAuthSessionID
:
session
.
ID
,
IdentityID
:
&
identity
.
ID
,
AdoptDisplayName
:
true
,
AdoptAvatar
:
true
,
}
results
:=
make
(
chan
adoptionResult
,
2
)
go
func
()
{
decision
,
err
:=
svc
.
UpsertAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
<-
firstCreateStarted
go
func
()
{
decision
,
err
:=
svc
.
UpsertAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
time
.
Sleep
(
100
*
time
.
Millisecond
)
close
(
releaseFirstCreate
)
first
:=
<-
results
second
:=
<-
results
require
.
NoError
(
t
,
first
.
err
)
require
.
NoError
(
t
,
second
.
err
)
require
.
NotNil
(
t
,
first
.
decision
)
require
.
NotNil
(
t
,
second
.
decision
)
require
.
Equal
(
t
,
first
.
decision
.
ID
,
second
.
decision
.
ID
)
count
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
loaded
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
loaded
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
loaded
.
IdentityID
)
}
func
TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference
(
t
*
testing
.
T
)
{
func
TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference
(
t
*
testing
.
T
)
{
t
.
Skip
(
"legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL"
)
t
.
Skip
(
"legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL"
)
...
@@ -356,3 +458,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
...
@@ -356,3 +458,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
_
,
err
=
svc
.
ConsumeBrowserSession
(
ctx
,
session
.
SessionToken
,
"browser-session"
)
_
,
err
=
svc
.
ConsumeBrowserSession
(
ctx
,
session
.
SessionToken
,
"browser-session"
)
require
.
ErrorIs
(
t
,
err
,
ErrPendingAuthSessionConsumed
)
require
.
ErrorIs
(
t
,
err
,
ErrPendingAuthSessionConsumed
)
}
}
func
TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay
(
t
*
testing
.
T
)
{
svc
,
_
:=
newAuthPendingIdentityServiceTestClient
(
t
)
ctx
:=
context
.
Background
()
session
,
err
:=
svc
.
CreatePendingSession
(
ctx
,
CreatePendingAuthSessionInput
{
Intent
:
"login"
,
Identity
:
PendingAuthIdentityKey
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"stale-replay-subject"
,
},
BrowserSessionKey
:
"browser-session"
,
})
require
.
NoError
(
t
,
err
)
loaded
,
err
:=
svc
.
getBrowserSession
(
ctx
,
session
.
SessionToken
)
require
.
NoError
(
t
,
err
)
consumed
,
err
:=
svc
.
consumeSession
(
ctx
,
loaded
,
"browser-session"
,
ErrPendingAuthSessionExpired
,
ErrPendingAuthSessionConsumed
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
_
,
err
=
svc
.
consumeSession
(
ctx
,
loaded
,
"browser-session"
,
ErrPendingAuthSessionExpired
,
ErrPendingAuthSessionConsumed
)
require
.
ErrorIs
(
t
,
err
,
ErrPendingAuthSessionConsumed
)
}
func
TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens
(
t
*
testing
.
T
)
{
svc
,
client
:=
newAuthPendingIdentityServiceTestClient
(
t
)
ctx
:=
context
.
Background
()
session
,
err
:=
svc
.
CreatePendingSession
(
ctx
,
CreatePendingAuthSessionInput
{
Intent
:
"login"
,
Identity
:
PendingAuthIdentityKey
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"legacy-token-subject"
,
},
BrowserSessionKey
:
"browser-session"
,
LocalFlowState
:
map
[
string
]
any
{
"completion_response"
:
map
[
string
]
any
{
"access_token"
:
"legacy-access-token"
,
"refresh_token"
:
"legacy-refresh-token"
,
"expires_in"
:
float64
(
3600
),
"token_type"
:
"Bearer"
,
"redirect"
:
"/dashboard"
,
},
},
})
require
.
NoError
(
t
,
err
)
consumed
,
err
:=
svc
.
ConsumeBrowserSession
(
ctx
,
session
.
SessionToken
,
"browser-session"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
stored
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
completion
,
ok
:=
stored
.
LocalFlowState
[
"completion_response"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
NotContains
(
t
,
completion
,
"access_token"
)
require
.
NotContains
(
t
,
completion
,
"refresh_token"
)
require
.
NotContains
(
t
,
completion
,
"expires_in"
)
require
.
NotContains
(
t
,
completion
,
"token_type"
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
}
backend/internal/service/auth_service.go
View file @
ddf80f5e
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"context"
"context"
"crypto/rand"
"crypto/rand"
"crypto/sha256"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/hex"
"errors"
"errors"
"fmt"
"fmt"
...
@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
...
@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Balance
:
grantPlan
.
Balance
,
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
Concurrency
:
grantPlan
.
Concurrency
,
Status
:
StatusActive
,
Status
:
StatusActive
,
SignupSource
:
signupSource
,
}
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
...
@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Balance
:
grantPlan
.
Balance
,
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
Concurrency
:
grantPlan
.
Concurrency
,
Status
:
StatusActive
,
Status
:
StatusActive
,
SignupSource
:
signupSource
,
}
}
if
s
.
entClient
!=
nil
&&
invitationRedeemCode
!=
nil
{
if
s
.
entClient
!=
nil
&&
invitationRedeemCode
!=
nil
{
...
@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
...
@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
Email
:
user
.
Email
,
Email
:
user
.
Email
,
Role
:
user
.
Role
,
Role
:
user
.
Role
,
TokenVersion
:
user
.
TokenVersion
,
TokenVersion
:
resolved
TokenVersion
(
user
)
,
RegisteredClaims
:
jwt
.
RegisteredClaims
{
RegisteredClaims
:
jwt
.
RegisteredClaims
{
ExpiresAt
:
jwt
.
NewNumericDate
(
expiresAt
),
ExpiresAt
:
jwt
.
NewNumericDate
(
expiresAt
),
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
...
@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
...
@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
// This ensures tokens issued before a password change cannot be refreshed
if
claims
.
TokenVersion
!=
user
.
TokenVersion
{
if
claims
.
TokenVersion
!=
resolved
TokenVersion
(
user
)
{
return
""
,
ErrTokenRevoked
return
""
,
ErrTokenRevoked
}
}
...
@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
...
@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data
:=
&
RefreshTokenData
{
data
:=
&
RefreshTokenData
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
TokenVersion
:
user
.
TokenVersion
,
TokenVersion
:
resolved
TokenVersion
(
user
)
,
FamilyID
:
familyID
,
FamilyID
:
familyID
,
CreatedAt
:
now
,
CreatedAt
:
now
,
ExpiresAt
:
now
.
Add
(
ttl
),
ExpiresAt
:
now
.
Add
(
ttl
),
...
@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
...
@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
}
// 检查TokenVersion(密码更改后所有Token失效)
// 检查TokenVersion(密码更改后所有Token失效)
if
data
.
TokenVersion
!=
user
.
TokenVersion
{
if
data
.
TokenVersion
!=
resolved
TokenVersion
(
user
)
{
// TokenVersion不匹配,撤销整个Token家族
// TokenVersion不匹配,撤销整个Token家族
_
=
s
.
refreshTokenCache
.
DeleteTokenFamily
(
ctx
,
data
.
FamilyID
)
_
=
s
.
refreshTokenCache
.
DeleteTokenFamily
(
ctx
,
data
.
FamilyID
)
return
nil
,
ErrTokenRevoked
return
nil
,
ErrTokenRevoked
...
@@ -1467,8 +1470,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
...
@@ -1467,8 +1470,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return
s
.
refreshTokenCache
.
DeleteUserRefreshTokens
(
ctx
,
userID
)
return
s
.
refreshTokenCache
.
DeleteUserRefreshTokens
(
ctx
,
userID
)
}
}
// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
// Access/refresh token verification both depend on TokenVersion, so bumping it provides
// immediate revocation even if refresh-token cache cleanup later fails.
func
(
s
*
AuthService
)
RevokeAllUserTokens
(
ctx
context
.
Context
,
userID
int64
)
error
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
user
.
TokenVersion
++
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update user: %w"
,
err
)
}
if
err
:=
s
.
RevokeAllUserSessions
(
ctx
,
userID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v"
,
userID
,
err
)
}
return
nil
}
// hashToken 计算Token的SHA256哈希
// hashToken 计算Token的SHA256哈希
func
hashToken
(
token
string
)
string
{
func
hashToken
(
token
string
)
string
{
hash
:=
sha256
.
Sum256
([]
byte
(
token
))
hash
:=
sha256
.
Sum256
([]
byte
(
token
))
return
hex
.
EncodeToString
(
hash
[
:
])
return
hex
.
EncodeToString
(
hash
[
:
])
}
}
func
resolvedTokenVersion
(
user
*
User
)
int64
{
if
user
==
nil
{
return
0
}
if
user
.
TokenVersionResolved
{
return
user
.
TokenVersion
}
material
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
user
.
Email
))
+
"
\n
"
+
user
.
PasswordHash
sum
:=
sha256
.
Sum256
([]
byte
(
material
))
fingerprint
:=
int64
(
binary
.
BigEndian
.
Uint64
(
sum
[
:
8
])
&
0x7fffffffffffffff
)
return
user
.
TokenVersion
^
fingerprint
}
backend/internal/service/auth_service_email_bind_test.go
View file @
ddf80f5e
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"context"
"context"
"database/sql"
"database/sql"
"errors"
"errors"
"sync"
"testing"
"testing"
"time"
"time"
...
@@ -13,6 +14,7 @@ import (
...
@@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
...
@@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
settings
map
[
string
]
string
,
settings
map
[
string
]
string
,
emailCache
service
.
EmailCache
,
emailCache
service
.
EmailCache
,
defaultSubAssigner
service
.
DefaultSubscriptionAssigner
,
defaultSubAssigner
service
.
DefaultSubscriptionAssigner
,
)
(
*
service
.
AuthService
,
service
.
UserRepository
,
*
dbent
.
Client
)
{
return
newAuthServiceForEmailBindWithRefreshCache
(
t
,
settings
,
emailCache
,
defaultSubAssigner
,
nil
)
}
func
newAuthServiceForEmailBindWithRefreshCache
(
t
*
testing
.
T
,
settings
map
[
string
]
string
,
emailCache
service
.
EmailCache
,
defaultSubAssigner
service
.
DefaultSubscriptionAssigner
,
refreshTokenCache
service
.
RefreshTokenCache
,
)
(
*
service
.
AuthService
,
service
.
UserRepository
,
*
dbent
.
Client
)
{
)
(
*
service
.
AuthService
,
service
.
UserRepository
,
*
dbent
.
Client
)
{
t
.
Helper
()
t
.
Helper
()
...
@@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
...
@@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc
=
service
.
NewEmailService
(
settingRepo
,
emailCache
)
emailSvc
=
service
.
NewEmailService
(
settingRepo
,
emailCache
)
}
}
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
nil
,
cfg
,
settingSvc
,
emailSvc
,
nil
,
nil
,
nil
,
defaultSubAssigner
)
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
settingSvc
,
emailSvc
,
nil
,
nil
,
nil
,
defaultSubAssigner
)
return
svc
,
repo
,
client
return
svc
,
repo
,
client
}
}
...
@@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
...
@@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
require
.
Equal
(
t
,
0
,
newIdentityCount
)
require
.
Equal
(
t
,
0
,
newIdentityCount
)
}
}
func
TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
cache
:=
&
emailBindCacheStub
{
data
:
&
service
.
VerificationCodeData
{
Code
:
"123456"
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
ExpiresAt
:
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
),
},
}
refreshTokenCache
:=
newEmailBindRefreshTokenCacheStub
()
userRepo
:=
newEmailBindUserRepoStub
(
&
service
.
User
{
ID
:
41
,
Email
:
"legacy-user"
+
service
.
OIDCConnectSyntheticEmailDomain
,
Username
:
"legacy-user"
,
PasswordHash
:
"old-hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
})
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-bind-email-secret"
,
ExpireHour
:
1
,
AccessTokenExpireMinutes
:
60
,
RefreshTokenExpireDays
:
7
,
},
}
emailService
:=
service
.
NewEmailService
(
nil
,
cache
)
svc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
emailService
,
nil
,
nil
,
nil
,
nil
)
oldTokenPair
,
err
:=
svc
.
GenerateTokenPair
(
ctx
,
&
service
.
User
{
ID
:
41
,
Email
:
"legacy-user"
+
service
.
OIDCConnectSyntheticEmailDomain
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
},
""
)
require
.
NoError
(
t
,
err
)
updatedUser
,
err
:=
svc
.
BindEmailIdentity
(
ctx
,
41
,
"new@example.com"
,
"123456"
,
"new-password"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
updatedUser
)
storedUser
,
err
:=
userRepo
.
GetByID
(
ctx
,
41
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"new@example.com"
,
storedUser
.
Email
)
require
.
True
(
t
,
svc
.
CheckPassword
(
"new-password"
,
storedUser
.
PasswordHash
))
_
,
err
=
svc
.
RefreshToken
(
ctx
,
oldTokenPair
.
AccessToken
)
require
.
ErrorIs
(
t
,
err
,
service
.
ErrTokenRevoked
)
_
,
err
=
svc
.
RefreshTokenPair
(
ctx
,
oldTokenPair
.
RefreshToken
)
require
.
True
(
t
,
errors
.
Is
(
err
,
service
.
ErrTokenRevoked
)
||
errors
.
Is
(
err
,
service
.
ErrRefreshTokenInvalid
))
}
type
emailBindSettingRepoStub
struct
{
type
emailBindSettingRepoStub
struct
{
values
map
[
string
]
string
values
map
[
string
]
string
}
}
...
@@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
...
@@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
func
(
s
*
emailBindCacheStub
)
IncrNotifyCodeUserRate
(
context
.
Context
,
int64
,
time
.
Duration
)
(
int64
,
error
)
{
func
(
s
*
emailBindCacheStub
)
IncrNotifyCodeUserRate
(
context
.
Context
,
int64
,
time
.
Duration
)
(
int64
,
error
)
{
return
0
,
nil
return
0
,
nil
}
}
type
emailBindRefreshTokenCacheStub
struct
{
mu
sync
.
Mutex
tokens
map
[
string
]
*
service
.
RefreshTokenData
userSets
map
[
int64
]
map
[
string
]
struct
{}
families
map
[
string
]
map
[
string
]
struct
{}
}
func
newEmailBindRefreshTokenCacheStub
()
*
emailBindRefreshTokenCacheStub
{
return
&
emailBindRefreshTokenCacheStub
{
tokens
:
make
(
map
[
string
]
*
service
.
RefreshTokenData
),
userSets
:
make
(
map
[
int64
]
map
[
string
]
struct
{}),
families
:
make
(
map
[
string
]
map
[
string
]
struct
{}),
}
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
StoreRefreshToken
(
_
context
.
Context
,
tokenHash
string
,
data
*
service
.
RefreshTokenData
,
_
time
.
Duration
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cloned
:=
*
data
s
.
tokens
[
tokenHash
]
=
&
cloned
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
GetRefreshToken
(
_
context
.
Context
,
tokenHash
string
)
(
*
service
.
RefreshTokenData
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
data
,
ok
:=
s
.
tokens
[
tokenHash
]
if
!
ok
{
return
nil
,
service
.
ErrRefreshTokenNotFound
}
cloned
:=
*
data
return
&
cloned
,
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
DeleteRefreshToken
(
_
context
.
Context
,
tokenHash
string
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
tokens
,
tokenHash
)
for
_
,
tokenSet
:=
range
s
.
userSets
{
delete
(
tokenSet
,
tokenHash
)
}
for
_
,
tokenSet
:=
range
s
.
families
{
delete
(
tokenSet
,
tokenHash
)
}
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
DeleteUserRefreshTokens
(
_
context
.
Context
,
userID
int64
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
for
tokenHash
:=
range
s
.
userSets
[
userID
]
{
delete
(
s
.
tokens
,
tokenHash
)
for
_
,
tokenSet
:=
range
s
.
families
{
delete
(
tokenSet
,
tokenHash
)
}
}
delete
(
s
.
userSets
,
userID
)
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
DeleteTokenFamily
(
_
context
.
Context
,
familyID
string
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
for
tokenHash
:=
range
s
.
families
[
familyID
]
{
delete
(
s
.
tokens
,
tokenHash
)
for
_
,
tokenSet
:=
range
s
.
userSets
{
delete
(
tokenSet
,
tokenHash
)
}
}
delete
(
s
.
families
,
familyID
)
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
AddToUserTokenSet
(
_
context
.
Context
,
userID
int64
,
tokenHash
string
,
_
time
.
Duration
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
userSets
[
userID
]
==
nil
{
s
.
userSets
[
userID
]
=
make
(
map
[
string
]
struct
{})
}
s
.
userSets
[
userID
][
tokenHash
]
=
struct
{}{}
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
AddToFamilyTokenSet
(
_
context
.
Context
,
familyID
string
,
tokenHash
string
,
_
time
.
Duration
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
families
[
familyID
]
==
nil
{
s
.
families
[
familyID
]
=
make
(
map
[
string
]
struct
{})
}
s
.
families
[
familyID
][
tokenHash
]
=
struct
{}{}
return
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
GetUserTokenHashes
(
_
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
tokenSet
:=
s
.
userSets
[
userID
]
out
:=
make
([]
string
,
0
,
len
(
tokenSet
))
for
tokenHash
:=
range
tokenSet
{
out
=
append
(
out
,
tokenHash
)
}
return
out
,
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
GetFamilyTokenHashes
(
_
context
.
Context
,
familyID
string
)
([]
string
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
tokenSet
:=
s
.
families
[
familyID
]
out
:=
make
([]
string
,
0
,
len
(
tokenSet
))
for
tokenHash
:=
range
tokenSet
{
out
=
append
(
out
,
tokenHash
)
}
return
out
,
nil
}
func
(
s
*
emailBindRefreshTokenCacheStub
)
IsTokenInFamily
(
_
context
.
Context
,
familyID
string
,
tokenHash
string
)
(
bool
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
_
,
ok
:=
s
.
families
[
familyID
][
tokenHash
]
return
ok
,
nil
}
type
emailBindUserRepoStub
struct
{
mu
sync
.
Mutex
usersByID
map
[
int64
]
*
service
.
User
usersByEmail
map
[
string
]
*
service
.
User
}
func
newEmailBindUserRepoStub
(
user
*
service
.
User
)
*
emailBindUserRepoStub
{
cloned
:=
cloneEmailBindUser
(
user
)
return
&
emailBindUserRepoStub
{
usersByID
:
map
[
int64
]
*
service
.
User
{
cloned
.
ID
:
cloned
,
},
usersByEmail
:
map
[
string
]
*
service
.
User
{
cloned
.
Email
:
cloned
,
},
}
}
func
(
s
*
emailBindUserRepoStub
)
Create
(
context
.
Context
,
*
service
.
User
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
GetByID
(
_
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
user
,
ok
:=
s
.
usersByID
[
id
]
if
!
ok
{
return
nil
,
service
.
ErrUserNotFound
}
return
cloneEmailBindUser
(
user
),
nil
}
func
(
s
*
emailBindUserRepoStub
)
GetByEmail
(
_
context
.
Context
,
email
string
)
(
*
service
.
User
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
user
,
ok
:=
s
.
usersByEmail
[
email
]
if
!
ok
{
return
nil
,
service
.
ErrUserNotFound
}
return
cloneEmailBindUser
(
user
),
nil
}
func
(
s
*
emailBindUserRepoStub
)
GetFirstAdmin
(
context
.
Context
)
(
*
service
.
User
,
error
)
{
panic
(
"unexpected GetFirstAdmin call"
)
}
func
(
s
*
emailBindUserRepoStub
)
Update
(
_
context
.
Context
,
user
*
service
.
User
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
existing
,
ok
:=
s
.
usersByID
[
user
.
ID
]
if
!
ok
{
return
service
.
ErrUserNotFound
}
delete
(
s
.
usersByEmail
,
existing
.
Email
)
cloned
:=
cloneEmailBindUser
(
user
)
s
.
usersByID
[
user
.
ID
]
=
cloned
s
.
usersByEmail
[
cloned
.
Email
]
=
cloned
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
GetUserAvatar
(
context
.
Context
,
int64
)
(
*
service
.
UserAvatar
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailBindUserRepoStub
)
UpsertUserAvatar
(
context
.
Context
,
int64
,
service
.
UpsertUserAvatarInput
)
(
*
service
.
UserAvatar
,
error
)
{
panic
(
"unexpected UpsertUserAvatar call"
)
}
func
(
s
*
emailBindUserRepoStub
)
DeleteUserAvatar
(
context
.
Context
,
int64
)
error
{
panic
(
"unexpected DeleteUserAvatar call"
)
}
func
(
s
*
emailBindUserRepoStub
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
emailBindUserRepoStub
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
emailBindUserRepoStub
)
GetLatestUsedAtByUserIDs
(
context
.
Context
,
[]
int64
)
(
map
[
int64
]
*
time
.
Time
,
error
)
{
return
map
[
int64
]
*
time
.
Time
{},
nil
}
func
(
s
*
emailBindUserRepoStub
)
GetLatestUsedAtByUserID
(
context
.
Context
,
int64
)
(
*
time
.
Time
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailBindUserRepoStub
)
UpdateUserLastActiveAt
(
context
.
Context
,
int64
,
time
.
Time
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
UpdateBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
ExistsByEmail
(
_
context
.
Context
,
email
string
)
(
bool
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
_
,
ok
:=
s
.
usersByEmail
[
email
]
return
ok
,
nil
}
func
(
s
*
emailBindUserRepoStub
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
s
*
emailBindUserRepoStub
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
ListUserAuthIdentities
(
context
.
Context
,
int64
)
([]
service
.
UserAuthIdentityRecord
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailBindUserRepoStub
)
UnbindUserAuthProvider
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
s
*
emailBindUserRepoStub
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
cloneEmailBindUser
(
user
*
service
.
User
)
*
service
.
User
{
if
user
==
nil
{
return
nil
}
cloned
:=
*
user
return
&
cloned
}
backend/internal/service/payment_config_limits.go
View file @
ddf80f5e
...
@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
...
@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return
nil
,
fmt
.
Errorf
(
"query provider instances: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"query provider instances: %w"
,
err
)
}
}
typeInstances
:=
pcGroupByPaymentType
(
instances
)
typeInstances
:=
pcGroupByPaymentType
(
instances
)
typeInstances
=
pcApplyEnabledVisibleMethodInstances
(
typeInstances
,
instances
)
typeInstances
=
s
.
pcApplyEnabledVisibleMethodInstances
(
ctx
,
typeInstances
,
instances
)
resp
:=
&
MethodLimitsResponse
{
resp
:=
&
MethodLimitsResponse
{
Methods
:
make
(
map
[
string
]
MethodLimits
,
len
(
typeInstances
)),
Methods
:
make
(
map
[
string
]
MethodLimits
,
len
(
typeInstances
)),
}
}
...
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
...
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return
resp
,
nil
return
resp
,
nil
}
}
func
pcApplyEnabledVisibleMethodInstances
(
typeInstances
map
[
string
][]
*
dbent
.
PaymentProviderInstance
,
instances
[]
*
dbent
.
PaymentProviderInstance
)
map
[
string
][]
*
dbent
.
PaymentProviderInstance
{
func
(
s
*
PaymentConfigService
)
pcApplyEnabledVisibleMethodInstances
(
ctx
context
.
Context
,
typeInstances
map
[
string
][]
*
dbent
.
PaymentProviderInstance
,
instances
[]
*
dbent
.
PaymentProviderInstance
)
map
[
string
][]
*
dbent
.
PaymentProviderInstance
{
if
len
(
typeInstances
)
==
0
{
if
len
(
typeInstances
)
==
0
{
return
typeInstances
return
typeInstances
}
}
...
@@ -44,11 +44,25 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
...
@@ -44,11 +44,25 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for
_
,
method
:=
range
[]
string
{
payment
.
TypeAlipay
,
payment
.
TypeWxpay
}
{
for
_
,
method
:=
range
[]
string
{
payment
.
TypeAlipay
,
payment
.
TypeWxpay
}
{
matching
:=
filterEnabledVisibleMethodInstances
(
instances
,
method
)
matching
:=
filterEnabledVisibleMethodInstances
(
instances
,
method
)
if
len
(
matching
)
!=
1
{
providerKey
,
err
:=
s
.
resolveVisibleMethodProviderKey
(
ctx
,
method
,
matching
)
if
err
!=
nil
{
delete
(
filtered
,
method
)
delete
(
filtered
,
method
)
continue
continue
}
}
filtered
[
method
]
=
[]
*
dbent
.
PaymentProviderInstance
{
matching
[
0
]}
if
providerKey
==
""
{
if
len
(
matching
)
==
0
{
delete
(
filtered
,
method
)
continue
}
filtered
[
method
]
=
matching
continue
}
selectedInstances
:=
filterVisibleMethodInstancesByProviderKey
(
instances
,
method
,
providerKey
)
if
len
(
selectedInstances
)
==
0
{
delete
(
filtered
,
method
)
continue
}
filtered
[
method
]
=
selectedInstances
}
}
return
filtered
return
filtered
}
}
...
...
backend/internal/service/payment_config_limits_test.go
View file @
ddf80f5e
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
)
func
TestUnionFloat
(
t
*
testing
.
T
)
{
func
TestUnionFloat
(
t
*
testing
.
T
)
{
...
@@ -301,7 +302,109 @@ func TestPcInstanceTypeLimits(t *testing.T) {
...
@@ -301,7 +302,109 @@ func TestPcInstanceTypeLimits(t *testing.T) {
})
})
}
}
func
TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders
(
t
*
testing
.
T
)
{
func
TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
sourceSetting
string
wantAlipaySingleMin
float64
wantAlipaySingleMax
float64
wantGlobalMin
float64
wantGlobalMax
float64
}{
{
name
:
"official source"
,
sourceSetting
:
VisibleMethodSourceOfficialAlipay
,
wantAlipaySingleMin
:
10
,
wantAlipaySingleMax
:
100
,
wantGlobalMin
:
10
,
wantGlobalMax
:
300
,
},
{
name
:
"easypay source"
,
sourceSetting
:
VisibleMethodSourceEasyPayAlipay
,
wantAlipaySingleMin
:
20
,
wantAlipaySingleMax
:
200
,
wantGlobalMin
:
20
,
wantGlobalMax
:
300
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeAlipay
)
.
SetName
(
"Official Alipay"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetLimits
(
`{"alipay":{"singleMin":10,"singleMax":100}}`
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official alipay instance: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
"EasyPay Alipay"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetLimits
(
`{"alipay":{"singleMin":20,"singleMax":200}}`
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay alipay instance: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeWxpay
)
.
SetName
(
"Official WeChat"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"wxpay"
)
.
SetLimits
(
`{"wxpay":{"singleMin":30,"singleMax":300}}`
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official wxpay instance: %v"
,
err
)
}
svc
:=
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingPaymentVisibleMethodAlipaySource
:
tt
.
sourceSetting
,
},
},
}
resp
,
err
:=
svc
.
GetAvailableMethodLimits
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetAvailableMethodLimits returned error: %v"
,
err
)
}
alipayLimits
,
ok
:=
resp
.
Methods
[
payment
.
TypeAlipay
]
if
!
ok
{
t
.
Fatalf
(
"expected alipay limits to remain visible, got %v"
,
resp
.
Methods
)
}
if
alipayLimits
.
SingleMin
!=
tt
.
wantAlipaySingleMin
||
alipayLimits
.
SingleMax
!=
tt
.
wantAlipaySingleMax
{
t
.
Fatalf
(
"alipay limits = %+v, want min=%v max=%v"
,
alipayLimits
,
tt
.
wantAlipaySingleMin
,
tt
.
wantAlipaySingleMax
)
}
wxpayLimits
,
ok
:=
resp
.
Methods
[
payment
.
TypeWxpay
]
if
!
ok
{
t
.
Fatalf
(
"expected wxpay limits to remain visible, got %v"
,
resp
.
Methods
)
}
if
wxpayLimits
.
SingleMin
!=
30
||
wxpayLimits
.
SingleMax
!=
300
{
t
.
Fatalf
(
"wxpay limits = %+v, want official-only min=30 max=300"
,
wxpayLimits
)
}
if
resp
.
GlobalMin
!=
tt
.
wantGlobalMin
||
resp
.
GlobalMax
!=
tt
.
wantGlobalMax
{
t
.
Fatalf
(
"global range = (%v, %v), want (%v, %v)"
,
resp
.
GlobalMin
,
resp
.
GlobalMax
,
tt
.
wantGlobalMin
,
tt
.
wantGlobalMax
)
}
})
}
}
func
TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
client
:=
newPaymentConfigServiceTestClient
(
t
)
...
@@ -313,20 +416,18 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
...
@@ -313,20 +416,18 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits
(
`{"alipay":{"singleMin":10,"singleMax":100}}`
)
.
SetLimits
(
`{"alipay":{"singleMin":10,"singleMax":100}}`
)
.
SetEnabled
(
true
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
Save
(
ctx
)
if
err
!=
nil
{
require
.
NoError
(
t
,
err
)
t
.
Fatalf
(
"create official alipay instance: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
"EasyPay
Alipay
"
)
.
SetName
(
"EasyPay
Mixed
"
)
.
SetConfig
(
"{}"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetSupportedTypes
(
"alipay
,wxpay
"
)
.
SetLimits
(
`{"alipay":{"singleMin":20,"singleMax":200}}`
)
.
SetLimits
(
`{"alipay":{"singleMin":20,"singleMax":200}
,"wxpay":{"singleMin":40,"singleMax":400}
}`
)
.
SetEnabled
(
true
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
Save
(
ctx
)
if
err
!=
nil
{
require
.
NoError
(
t
,
err
)
t
.
Fatalf
(
"create easypay alipay instance: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeWxpay
)
.
SetProviderKey
(
payment
.
TypeWxpay
)
.
SetName
(
"Official WeChat"
)
.
SetName
(
"Official WeChat"
)
.
...
@@ -335,31 +436,26 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
...
@@ -335,31 +436,26 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits
(
`{"wxpay":{"singleMin":30,"singleMax":300}}`
)
.
SetLimits
(
`{"wxpay":{"singleMin":30,"singleMax":300}}`
)
.
SetEnabled
(
true
)
.
SetEnabled
(
true
)
.
Save
(
ctx
)
Save
(
ctx
)
if
err
!=
nil
{
require
.
NoError
(
t
,
err
)
t
.
Fatalf
(
"create official wxpay instance: %v"
,
err
)
}
svc
:=
&
PaymentConfigService
{
svc
:=
&
PaymentConfigService
{
entClient
:
client
,
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{}},
}
}
resp
,
err
:=
svc
.
GetAvailableMethodLimits
(
ctx
)
resp
,
err
:=
svc
.
GetAvailableMethodLimits
(
ctx
)
if
err
!=
nil
{
require
.
NoError
(
t
,
err
)
t
.
Fatalf
(
"GetAvailableMethodLimits returned error: %v"
,
err
)
}
if
_
,
ok
:=
resp
.
Methods
[
payment
.
TypeAlipay
];
ok
{
alipayLimits
,
ok
:=
resp
.
Methods
[
payment
.
TypeAlipay
]
t
.
Fatalf
(
"alipay should be hidden when multiple enabled providers claim it, got %v"
,
resp
.
Methods
[
payment
.
TypeAlipay
])
require
.
True
(
t
,
ok
,
"expected alipay limits to remain visible"
)
}
require
.
Equal
(
t
,
10.0
,
alipayLimits
.
SingleMin
)
require
.
Equal
(
t
,
200.0
,
alipayLimits
.
SingleMax
)
wxpayLimits
,
ok
:=
resp
.
Methods
[
payment
.
TypeWxpay
]
wxpayLimits
,
ok
:=
resp
.
Methods
[
payment
.
TypeWxpay
]
if
!
ok
{
require
.
True
(
t
,
ok
,
"expected wxpay limits to remain visible"
)
t
.
Fatalf
(
"expected wxpay limits to remain visible, got %v"
,
resp
.
Methods
)
require
.
Equal
(
t
,
30.0
,
wxpayLimits
.
SingleMin
)
}
require
.
Equal
(
t
,
400.0
,
wxpayLimits
.
SingleMax
)
if
wxpayLimits
.
SingleMin
!=
30
||
wxpayLimits
.
SingleMax
!=
300
{
t
.
Fatalf
(
"wxpay limits = %+v, want official-only min=30 max=300"
,
wxpayLimits
)
require
.
Equal
(
t
,
10.0
,
resp
.
GlobalMin
)
}
require
.
Equal
(
t
,
400.0
,
resp
.
GlobalMax
)
if
resp
.
GlobalMin
!=
30
||
resp
.
GlobalMax
!=
300
{
t
.
Fatalf
(
"global range = (%v, %v), want (30, 300)"
,
resp
.
GlobalMin
,
resp
.
GlobalMax
)
}
}
}
backend/internal/service/payment_config_providers.go
View file @
ddf80f5e
...
@@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
...
@@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment
.
TypeStripe
:
{
"secretkey"
:
{},
"webhooksecret"
:
{}},
payment
.
TypeStripe
:
{
"secretkey"
:
{},
"webhooksecret"
:
{}},
}
}
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
// changed while the instance has in-progress orders. This includes secrets plus
// all provider identity fields that are snapshotted into orders or used by
// webhook/refund verification.
var
providerPendingOrderProtectedConfigFields
=
map
[
string
]
map
[
string
]
struct
{}{
payment
.
TypeEasyPay
:
{
"pkey"
:
{},
"pid"
:
{}},
payment
.
TypeAlipay
:
{
"privatekey"
:
{},
"publickey"
:
{},
"alipaypublickey"
:
{},
"appid"
:
{}},
payment
.
TypeWxpay
:
{
"privatekey"
:
{},
"apiv3key"
:
{},
"publickey"
:
{},
"appid"
:
{},
"mpappid"
:
{},
"mchid"
:
{},
"publickeyid"
:
{},
"certserial"
:
{}},
payment
.
TypeStripe
:
{
"secretkey"
:
{},
"webhooksecret"
:
{}},
}
func
isSensitiveProviderConfigField
(
providerKey
,
fieldName
string
)
bool
{
func
isSensitiveProviderConfigField
(
providerKey
,
fieldName
string
)
bool
{
fields
,
ok
:=
providerSensitiveConfigFields
[
providerKey
]
fields
,
ok
:=
providerSensitiveConfigFields
[
providerKey
]
if
!
ok
{
if
!
ok
{
...
@@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
...
@@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
return
found
return
found
}
}
func
hasPendingOrderProtectedConfigChange
(
providerKey
string
,
currentConfig
,
nextConfig
map
[
string
]
string
)
bool
{
fields
,
ok
:=
providerPendingOrderProtectedConfigFields
[
providerKey
]
if
!
ok
{
return
false
}
for
fieldName
:=
range
fields
{
if
providerConfigFieldValue
(
currentConfig
,
fieldName
)
!=
providerConfigFieldValue
(
nextConfig
,
fieldName
)
{
return
true
}
}
return
false
}
func
providerConfigFieldValue
(
config
map
[
string
]
string
,
fieldName
string
)
string
{
for
key
,
value
:=
range
config
{
if
strings
.
EqualFold
(
key
,
fieldName
)
{
return
value
}
}
return
""
}
func
(
s
*
PaymentConfigService
)
countPendingOrders
(
ctx
context
.
Context
,
providerInstanceID
int64
)
(
int
,
error
)
{
func
(
s
*
PaymentConfigService
)
countPendingOrders
(
ctx
context
.
Context
,
providerInstanceID
int64
)
(
int
,
error
)
{
return
s
.
entClient
.
PaymentOrder
.
Query
()
.
return
s
.
entClient
.
PaymentOrder
.
Query
()
.
Where
(
Where
(
...
@@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"load provider instance: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"load provider instance: %w"
,
err
)
}
}
var
pendingOrderCount
*
int
getPendingOrderCount
:=
func
()
(
int
,
error
)
{
if
pendingOrderCount
!=
nil
{
return
*
pendingOrderCount
,
nil
}
count
,
err
:=
s
.
countPendingOrders
(
ctx
,
id
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"check pending orders: %w"
,
err
)
}
pendingOrderCount
=
&
count
return
count
,
nil
}
nextEnabled
:=
current
.
Enabled
nextEnabled
:=
current
.
Enabled
if
req
.
Enabled
!=
nil
{
if
req
.
Enabled
!=
nil
{
nextEnabled
=
*
req
.
Enabled
nextEnabled
=
*
req
.
Enabled
...
@@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if
err
:=
s
.
validateVisibleMethodEnablementConflicts
(
ctx
,
id
,
current
.
ProviderKey
,
nextSupportedTypes
,
nextEnabled
);
err
!=
nil
{
if
err
:=
s
.
validateVisibleMethodEnablementConflicts
(
ctx
,
id
,
current
.
ProviderKey
,
nextSupportedTypes
,
nextEnabled
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
var
mergedConfig
map
[
string
]
string
if
req
.
Config
!=
nil
{
if
req
.
Config
!=
nil
{
hasSensitive
:=
false
currentConfig
,
err
:=
s
.
decryptConfig
(
current
.
Config
)
for
k
,
v
:=
range
req
.
Config
{
if
err
!=
nil
{
if
v
!=
""
&&
isSensitiveProviderConfigField
(
current
.
ProviderKey
,
k
)
{
return
nil
,
fmt
.
Errorf
(
"decrypt existing config: %w"
,
err
)
hasSensitive
=
true
break
}
}
}
if
hasSensitive
{
mergedConfig
,
err
=
s
.
mergeConfig
(
ctx
,
id
,
req
.
Config
)
count
,
err
:=
s
.
countPendingOrders
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
hasPendingOrderProtectedConfigChange
(
current
.
ProviderKey
,
currentConfig
,
mergedConfig
)
{
count
,
err
:=
getPendingOrderCount
()
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check pending orders: %w"
,
err
)
return
nil
,
err
}
}
if
count
>
0
{
if
count
>
0
{
return
nil
,
infraerrors
.
Conflict
(
"PENDING_ORDERS"
,
"instance has pending orders"
)
.
return
nil
,
infraerrors
.
Conflict
(
"PENDING_ORDERS"
,
"instance has pending orders"
)
.
...
@@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
}
}
if
req
.
Enabled
!=
nil
&&
!*
req
.
Enabled
{
if
req
.
Enabled
!=
nil
&&
!*
req
.
Enabled
{
count
,
err
:=
s
.
coun
tPendingOrder
s
(
ctx
,
id
)
count
,
err
:=
ge
tPendingOrder
Count
(
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check pending orders: %w"
,
err
)
return
nil
,
err
}
}
if
count
>
0
{
if
count
>
0
{
return
nil
,
infraerrors
.
Conflict
(
"PENDING_ORDERS"
,
"instance has pending orders"
)
.
return
nil
,
infraerrors
.
Conflict
(
"PENDING_ORDERS"
,
"instance has pending orders"
)
.
...
@@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if
req
.
Enabled
!=
nil
{
if
req
.
Enabled
!=
nil
{
finalEnabled
=
*
req
.
Enabled
finalEnabled
=
*
req
.
Enabled
}
}
var
mergedConfig
map
[
string
]
string
if
req
.
Config
!=
nil
{
mergedConfig
,
err
=
s
.
mergeConfig
(
ctx
,
id
,
req
.
Config
)
if
err
!=
nil
{
return
nil
,
err
}
}
if
finalEnabled
{
if
finalEnabled
{
configToValidate
:=
mergedConfig
configToValidate
:=
mergedConfig
if
configToValidate
==
nil
{
if
configToValidate
==
nil
{
...
@@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
...
@@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if
req
.
SupportedTypes
!=
nil
{
if
req
.
SupportedTypes
!=
nil
{
// Check pending orders before removing payment types
// Check pending orders before removing payment types
count
,
err
:=
s
.
coun
tPendingOrder
s
(
ctx
,
id
)
count
,
err
:=
ge
tPendingOrder
Count
(
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check pending orders: %w"
,
err
)
return
nil
,
err
}
}
if
count
>
0
{
if
count
>
0
{
// Load current instance to compare types
// Load current instance to compare types
...
...
backend/internal/service/payment_config_providers_test.go
View file @
ddf80f5e
...
@@ -4,8 +4,16 @@ package service
...
@@ -4,8 +4,16 @@ package service
import
(
import
(
"context"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"strconv"
"testing"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -199,7 +207,7 @@ func TestJoinTypes(t *testing.T) {
...
@@ -199,7 +207,7 @@ func TestJoinTypes(t *testing.T) {
}
}
}
}
func
TestCreateProviderInstance
RejectsConflictingVisibleMethodEnablement
(
t
*
testing
.
T
)
{
func
TestCreateProviderInstance
AllowsVisibleMethodProvidersFromDifferentSources
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -227,15 +235,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
...
@@ -227,15 +235,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_
,
err
=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
_
,
err
=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
ProviderKey
:
"alipay"
,
ProviderKey
:
"alipay"
,
Name
:
"Official Alipay"
,
Name
:
"Official Alipay"
,
Config
:
map
[
string
]
string
{
"appId"
:
"app-1"
},
Config
:
map
[
string
]
string
{
"appId"
:
"app-1"
,
"privateKey"
:
"private-key"
},
SupportedTypes
:
[]
string
{
"alipay"
},
SupportedTypes
:
[]
string
{
"alipay"
},
Enabled
:
true
,
Enabled
:
true
,
})
})
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"PAYMENT_PROVIDER_CONFLICT"
,
infraerrors
.
Reason
(
err
))
}
}
func
TestUpdateProviderInstance
Reject
sEnabling
Conflicting
VisibleMethodProvider
(
t
*
testing
.
T
)
{
func
TestUpdateProviderInstance
Allow
sEnablingVisibleMethodProvider
FromDifferentSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -264,7 +271,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
...
@@ -264,7 +271,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate
,
err
:=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
candidate
,
err
:=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
ProviderKey
:
"wxpay"
,
ProviderKey
:
"wxpay"
,
Name
:
"Official WeChat"
,
Name
:
"Official WeChat"
,
Config
:
map
[
string
]
string
{
"appId"
:
"wx-app"
}
,
Config
:
validWxpayProviderConfig
(
t
)
,
SupportedTypes
:
[]
string
{
"wxpay"
},
SupportedTypes
:
[]
string
{
"wxpay"
},
Enabled
:
false
,
Enabled
:
false
,
})
})
...
@@ -273,8 +280,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
...
@@ -273,8 +280,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_
,
err
=
svc
.
UpdateProviderInstance
(
ctx
,
candidate
.
ID
,
UpdateProviderInstanceRequest
{
_
,
err
=
svc
.
UpdateProviderInstance
(
ctx
,
candidate
.
ID
,
UpdateProviderInstanceRequest
{
Enabled
:
boolPtrValue
(
true
),
Enabled
:
boolPtrValue
(
true
),
})
})
require
.
Error
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"PAYMENT_PROVIDER_CONFLICT"
,
infraerrors
.
Reason
(
err
))
}
}
func
TestUpdateProviderInstancePersistsEnabledAndSupportedTypes
(
t
*
testing
.
T
)
{
func
TestUpdateProviderInstancePersistsEnabledAndSupportedTypes
(
t
*
testing
.
T
)
{
...
@@ -314,6 +320,289 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
...
@@ -314,6 +320,289 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
require
.
Equal
(
t
,
"alipay,wxpay"
,
saved
.
SupportedTypes
)
require
.
Equal
(
t
,
"alipay,wxpay"
,
saved
.
SupportedTypes
)
}
}
func
TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
providerKey
string
createConfig
func
(
*
testing
.
T
)
map
[
string
]
string
supportedType
[]
string
updateConfig
map
[
string
]
string
fieldName
string
wantValue
string
}{
{
name
:
"wxpay appId"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"appId"
:
"wx-app-updated"
},
fieldName
:
"appId"
,
wantValue
:
"wx-app-test"
,
},
{
name
:
"wxpay mpAppId"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfigWithJSAPIAppID
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"mpAppId"
:
"wx-mp-app-updated"
},
fieldName
:
"mpAppId"
,
wantValue
:
"wx-mp-app-test"
,
},
{
name
:
"wxpay mchId"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"mchId"
:
"mch-updated"
},
fieldName
:
"mchId"
,
wantValue
:
"mch-test"
,
},
{
name
:
"wxpay publicKeyId"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"publicKeyId"
:
"public-key-id-updated"
},
fieldName
:
"publicKeyId"
,
wantValue
:
"public-key-id-test"
,
},
{
name
:
"wxpay certSerial"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"certSerial"
:
"cert-serial-updated"
},
fieldName
:
"certSerial"
,
wantValue
:
"cert-serial-test"
,
},
{
name
:
"alipay appId"
,
providerKey
:
payment
.
TypeAlipay
,
createConfig
:
validAlipayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeAlipay
},
updateConfig
:
map
[
string
]
string
{
"appId"
:
"alipay-app-updated"
},
fieldName
:
"appId"
,
wantValue
:
"alipay-app-test"
,
},
{
name
:
"easypay pid"
,
providerKey
:
payment
.
TypeEasyPay
,
createConfig
:
validEasyPayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeAlipay
},
updateConfig
:
map
[
string
]
string
{
"pid"
:
"pid-updated"
},
fieldName
:
"pid"
,
wantValue
:
"pid-test"
,
},
}
for
_
,
tc
:=
range
tests
{
tc
:=
tc
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
svc
:=
&
PaymentConfigService
{
entClient
:
client
,
encryptionKey
:
[]
byte
(
"0123456789abcdef0123456789abcdef"
),
}
instance
,
err
:=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
ProviderKey
:
tc
.
providerKey
,
Name
:
"protected-config-instance"
,
Config
:
tc
.
createConfig
(
t
),
SupportedTypes
:
tc
.
supportedType
,
Enabled
:
true
,
})
require
.
NoError
(
t
,
err
)
createPendingProviderConfigOrder
(
t
,
ctx
,
client
,
instance
)
updated
,
err
:=
svc
.
UpdateProviderInstance
(
ctx
,
instance
.
ID
,
UpdateProviderInstanceRequest
{
Config
:
tc
.
updateConfig
,
})
require
.
Nil
(
t
,
updated
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"PENDING_ORDERS"
,
infraerrors
.
Reason
(
err
))
saved
,
err
:=
client
.
PaymentProviderInstance
.
Get
(
ctx
,
instance
.
ID
)
require
.
NoError
(
t
,
err
)
cfg
,
err
:=
svc
.
decryptConfig
(
saved
.
Config
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
tc
.
wantValue
,
cfg
[
tc
.
fieldName
])
})
}
}
func
TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
providerKey
string
createConfig
func
(
*
testing
.
T
)
map
[
string
]
string
supportedType
[]
string
updateConfig
map
[
string
]
string
fieldName
string
wantValue
string
}{
{
name
:
"wxpay notifyUrl"
,
providerKey
:
payment
.
TypeWxpay
,
createConfig
:
validWxpayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeWxpay
},
updateConfig
:
map
[
string
]
string
{
"notifyUrl"
:
"https://merchant.example.com/wxpay/notify-v2"
},
fieldName
:
"notifyUrl"
,
wantValue
:
"https://merchant.example.com/wxpay/notify-v2"
,
},
{
name
:
"alipay same appId"
,
providerKey
:
payment
.
TypeAlipay
,
createConfig
:
validAlipayProviderConfig
,
supportedType
:
[]
string
{
payment
.
TypeAlipay
},
updateConfig
:
map
[
string
]
string
{
"appId"
:
"alipay-app-test"
},
fieldName
:
"appId"
,
wantValue
:
"alipay-app-test"
,
},
}
for
_
,
tc
:=
range
tests
{
tc
:=
tc
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
svc
:=
&
PaymentConfigService
{
entClient
:
client
,
encryptionKey
:
[]
byte
(
"0123456789abcdef0123456789abcdef"
),
}
instance
,
err
:=
svc
.
CreateProviderInstance
(
ctx
,
CreateProviderInstanceRequest
{
ProviderKey
:
tc
.
providerKey
,
Name
:
"safe-config-instance"
,
Config
:
tc
.
createConfig
(
t
),
SupportedTypes
:
tc
.
supportedType
,
Enabled
:
true
,
})
require
.
NoError
(
t
,
err
)
createPendingProviderConfigOrder
(
t
,
ctx
,
client
,
instance
)
updated
,
err
:=
svc
.
UpdateProviderInstance
(
ctx
,
instance
.
ID
,
UpdateProviderInstanceRequest
{
Config
:
tc
.
updateConfig
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
updated
)
saved
,
err
:=
client
.
PaymentProviderInstance
.
Get
(
ctx
,
instance
.
ID
)
require
.
NoError
(
t
,
err
)
cfg
,
err
:=
svc
.
decryptConfig
(
saved
.
Config
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
tc
.
wantValue
,
cfg
[
tc
.
fieldName
])
})
}
}
func
createPendingProviderConfigOrder
(
t
*
testing
.
T
,
ctx
context
.
Context
,
client
*
dbent
.
Client
,
instance
*
dbent
.
PaymentProviderInstance
)
{
t
.
Helper
()
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"provider-config-pending@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"provider-config-pending-user"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
instanceID
:=
strconv
.
FormatInt
(
instance
.
ID
,
10
)
_
,
err
=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
88
)
.
SetPayAmount
(
88
)
.
SetFeeRate
(
0
)
.
SetRechargeCode
(
"PENDING-PROVIDER-CONFIG-"
+
instanceID
)
.
SetOutTradeNo
(
"sub2_pending_provider_config_"
+
instanceID
)
.
SetPaymentType
(
providerPendingOrderPaymentType
(
instance
.
ProviderKey
))
.
SetPaymentTradeNo
(
""
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
OrderStatusPending
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
SetProviderInstanceID
(
instanceID
)
.
SetProviderKey
(
instance
.
ProviderKey
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
}
func
providerPendingOrderPaymentType
(
providerKey
string
)
string
{
switch
providerKey
{
case
payment
.
TypeWxpay
:
return
payment
.
TypeWxpay
case
payment
.
TypeAlipay
:
return
payment
.
TypeAlipay
default
:
return
payment
.
TypeAlipay
}
}
func
boolPtrValue
(
v
bool
)
*
bool
{
func
boolPtrValue
(
v
bool
)
*
bool
{
return
&
v
return
&
v
}
}
func
validAlipayProviderConfig
(
t
*
testing
.
T
)
map
[
string
]
string
{
t
.
Helper
()
return
map
[
string
]
string
{
"appId"
:
"alipay-app-test"
,
"privateKey"
:
"alipay-private-key-test"
,
"notifyUrl"
:
"https://merchant.example.com/alipay/notify"
,
"returnUrl"
:
"https://merchant.example.com/alipay/return"
,
}
}
func
validEasyPayProviderConfig
(
t
*
testing
.
T
)
map
[
string
]
string
{
t
.
Helper
()
return
map
[
string
]
string
{
"pid"
:
"pid-test"
,
"pkey"
:
"pkey-test"
,
"apiBase"
:
"https://pay.example.com"
,
"notifyUrl"
:
"https://merchant.example.com/easypay/notify"
,
"returnUrl"
:
"https://merchant.example.com/easypay/return"
,
}
}
func
validWxpayProviderConfig
(
t
*
testing
.
T
)
map
[
string
]
string
{
t
.
Helper
()
key
,
err
:=
rsa
.
GenerateKey
(
rand
.
Reader
,
2048
)
require
.
NoError
(
t
,
err
)
privDER
,
err
:=
x509
.
MarshalPKCS8PrivateKey
(
key
)
require
.
NoError
(
t
,
err
)
pubDER
,
err
:=
x509
.
MarshalPKIXPublicKey
(
&
key
.
PublicKey
)
require
.
NoError
(
t
,
err
)
return
map
[
string
]
string
{
"appId"
:
"wx-app-test"
,
"mchId"
:
"mch-test"
,
"privateKey"
:
string
(
pem
.
EncodeToMemory
(
&
pem
.
Block
{
Type
:
"PRIVATE KEY"
,
Bytes
:
privDER
})),
"apiV3Key"
:
"12345678901234567890123456789012"
,
"publicKey"
:
string
(
pem
.
EncodeToMemory
(
&
pem
.
Block
{
Type
:
"PUBLIC KEY"
,
Bytes
:
pubDER
})),
"publicKeyId"
:
"public-key-id-test"
,
"certSerial"
:
"cert-serial-test"
,
}
}
func
validWxpayProviderConfigWithJSAPIAppID
(
t
*
testing
.
T
)
map
[
string
]
string
{
t
.
Helper
()
cfg
:=
validWxpayProviderConfig
(
t
)
cfg
[
"mpAppId"
]
=
"wx-mp-app-test"
return
cfg
}
backend/internal/service/payment_fulfillment.go
View file @
ddf80f5e
...
@@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
...
@@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
})
return
err
return
err
}
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
if
!
isValidProviderAmount
(
paid
)
{
// Also skip if paid is NaN/Inf (malformed provider data).
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
"PAYMENT_INVALID_AMOUNT"
,
pk
,
map
[
string
]
any
{
if
paid
>
0
&&
!
math
.
IsNaN
(
paid
)
&&
!
math
.
IsInf
(
paid
,
0
)
{
"expected"
:
o
.
PayAmount
,
if
math
.
Abs
(
paid
-
o
.
PayAmount
)
>
amountToleranceCNY
{
"paid"
:
paid
,
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
"PAYMENT_AMOUNT_MISMATCH"
,
pk
,
map
[
string
]
any
{
"expected"
:
o
.
PayAmount
,
"paid"
:
paid
,
"tradeNo"
:
tradeNo
})
"tradeNo"
:
tradeNo
,
return
fmt
.
Errorf
(
"amount mismatch: expected %.2f, got %.2f"
,
o
.
PayAmount
,
paid
)
}
)
}
return
fmt
.
Errorf
(
"invalid paid amount from provider: %v"
,
paid
)
}
}
// Use order's expected amount when provider didn't report one
if
math
.
Abs
(
paid
-
o
.
PayAmount
)
>
amountToleranceCNY
{
if
paid
<=
0
||
math
.
IsNaN
(
paid
)
||
math
.
IsInf
(
paid
,
0
)
{
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
"PAYMENT_AMOUNT_MISMATCH"
,
pk
,
map
[
string
]
any
{
"expected"
:
o
.
PayAmount
,
"paid"
:
paid
,
"tradeNo"
:
tradeNo
})
paid
=
o
.
PayAmount
return
fmt
.
Errorf
(
"amount mismatch: expected %.2f, got %.2f"
,
o
.
PayAmount
,
paid
)
}
}
return
s
.
toPaid
(
ctx
,
o
,
tradeNo
,
paid
,
pk
)
return
s
.
toPaid
(
ctx
,
o
,
tradeNo
,
paid
,
pk
)
}
}
func
isValidProviderAmount
(
amount
float64
)
bool
{
return
amount
>
0
&&
!
math
.
IsNaN
(
amount
)
&&
!
math
.
IsInf
(
amount
,
0
)
}
func
validateProviderNotificationMetadata
(
order
*
dbent
.
PaymentOrder
,
providerKey
string
,
metadata
map
[
string
]
string
)
error
{
func
validateProviderNotificationMetadata
(
order
*
dbent
.
PaymentOrder
,
providerKey
string
,
metadata
map
[
string
]
string
)
error
{
return
validateProviderSnapshotMetadata
(
order
,
providerKey
,
metadata
)
return
validateProviderSnapshotMetadata
(
order
,
providerKey
,
metadata
)
}
}
...
...
backend/internal/service/payment_fulfillment_test.go
View file @
ddf80f5e
...
@@ -5,6 +5,7 @@ package service
...
@@ -5,6 +5,7 @@ package service
import
(
import
(
"context"
"context"
"errors"
"errors"
"math"
"testing"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
@@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
...
@@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
assert
.
False
(
t
,
ok
)
assert
.
False
(
t
,
ok
)
}
}
func
TestIsValidProviderAmount
(
t
*
testing
.
T
)
{
t
.
Parallel
()
assert
.
True
(
t
,
isValidProviderAmount
(
0.01
))
assert
.
False
(
t
,
isValidProviderAmount
(
0
))
assert
.
False
(
t
,
isValidProviderAmount
(
-
1
))
assert
.
False
(
t
,
isValidProviderAmount
(
math
.
NaN
()))
assert
.
False
(
t
,
isValidProviderAmount
(
math
.
Inf
(
1
)))
}
func
TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch
(
t
*
testing
.
T
)
{
func
TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
...
...
backend/internal/service/payment_order.go
View file @
ddf80f5e
...
@@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
...
@@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm
=
defaultOrderTimeoutMin
tm
=
defaultOrderTimeoutMin
}
}
exp
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
tm
)
*
time
.
Minute
)
exp
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
tm
)
*
time
.
Minute
)
outTradeNo
,
err
:=
s
.
allocateOutTradeNo
(
ctx
,
tx
)
if
err
!=
nil
{
return
nil
,
err
}
providerSnapshot
:=
buildPaymentOrderProviderSnapshot
(
sel
,
req
)
providerSnapshot
:=
buildPaymentOrderProviderSnapshot
(
sel
,
req
)
selectedInstanceID
:=
""
selectedInstanceID
:=
""
selectedProviderKey
:=
""
selectedProviderKey
:=
""
...
@@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
...
@@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount
(
payAmount
)
.
SetPayAmount
(
payAmount
)
.
SetFeeRate
(
feeRate
)
.
SetFeeRate
(
feeRate
)
.
SetRechargeCode
(
""
)
.
SetRechargeCode
(
""
)
.
SetOutTradeNo
(
generateO
utTradeNo
()
)
.
SetOutTradeNo
(
o
utTradeNo
)
.
SetPaymentType
(
req
.
PaymentType
)
.
SetPaymentType
(
req
.
PaymentType
)
.
SetPaymentTradeNo
(
""
)
.
SetPaymentTradeNo
(
""
)
.
SetOrderType
(
req
.
OrderType
)
.
SetOrderType
(
req
.
OrderType
)
.
...
@@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
...
@@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return
order
,
nil
return
order
,
nil
}
}
func
(
s
*
PaymentService
)
allocateOutTradeNo
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
)
(
string
,
error
)
{
const
maxAttempts
=
5
for
attempt
:=
0
;
attempt
<
maxAttempts
;
attempt
++
{
candidate
:=
generateOutTradeNo
()
exists
,
err
:=
tx
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
OutTradeNo
(
candidate
))
.
Exist
(
ctx
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"check out_trade_no uniqueness: %w"
,
err
)
}
if
!
exists
{
return
candidate
,
nil
}
}
return
""
,
fmt
.
Errorf
(
"generate unique out_trade_no: exhausted %d attempts"
,
maxAttempts
)
}
func
(
s
*
PaymentService
)
checkPendingLimit
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
,
userID
int64
,
max
int
)
error
{
func
(
s
*
PaymentService
)
checkPendingLimit
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
,
userID
int64
,
max
int
)
error
{
if
max
<=
0
{
if
max
<=
0
{
max
=
defaultMaxPendingOrders
max
=
defaultMaxPendingOrders
...
@@ -360,13 +379,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
...
@@ -360,13 +379,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
}
subject
:=
s
.
buildPaymentSubject
(
plan
,
limitAmount
,
cfg
)
subject
:=
s
.
buildPaymentSubject
(
plan
,
limitAmount
,
cfg
)
outTradeNo
:=
order
.
OutTradeNo
outTradeNo
:=
order
.
OutTradeNo
canonicalReturnURL
,
err
:=
CanonicalizeReturnURL
(
req
.
ReturnURL
,
req
.
SrcHost
)
canonicalReturnURL
,
err
:=
CanonicalizeReturnURL
(
req
.
ReturnURL
,
req
.
SrcHost
,
req
.
SrcURL
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
resumeToken
:=
""
resumeToken
:=
""
if
resume
:=
s
.
paymentResume
();
resume
!=
nil
{
if
resume
:=
s
.
paymentResume
();
resume
!=
nil
{
if
resume
.
isSigningConfigured
()
{
if
canonicalReturnURL
!=
""
&&
resume
.
isSigningConfigured
()
{
resumeToken
,
err
=
resume
.
CreateToken
(
ResumeTokenClaims
{
resumeToken
,
err
=
resume
.
CreateToken
(
ResumeTokenClaims
{
OrderID
:
order
.
ID
,
OrderID
:
order
.
ID
,
UserID
:
order
.
UserID
,
UserID
:
order
.
UserID
,
...
@@ -380,7 +399,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
...
@@ -380,7 +399,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
}
}
}
}
}
providerReturnURL
,
err
:=
buildPaymentReturnURL
(
canonicalReturnURL
,
order
.
ID
,
resumeToken
)
providerReturnURL
,
err
:=
buildPaymentReturnURL
(
canonicalReturnURL
,
order
.
ID
,
outTradeNo
,
resumeToken
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -482,6 +501,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
...
@@ -482,6 +501,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
if
err
:=
s
.
paymentResume
()
.
ensureSigningKey
();
err
!=
nil
{
return
nil
,
err
}
authorizeURL
,
err
:=
buildWeChatPaymentOAuthStartURL
(
req
,
"snsapi_base"
)
authorizeURL
,
err
:=
buildWeChatPaymentOAuthStartURL
(
req
,
"snsapi_base"
)
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/service/payment_order_jsapi_test.go
View file @
ddf80f5e
...
@@ -31,3 +31,68 @@ func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *tes
...
@@ -31,3 +31,68 @@ func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *tes
t
.
Fatal
(
"expected official wxpay visible method to be detected from enabled provider instance"
)
t
.
Fatal
(
"expected official wxpay visible method to be detected from enabled provider instance"
)
}
}
}
}
func
TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
source
string
wantOfficial
bool
}{
{
name
:
"official source selected"
,
source
:
VisibleMethodSourceOfficialWechat
,
wantOfficial
:
true
,
},
{
name
:
"easypay source selected"
,
source
:
VisibleMethodSourceEasyPayWechat
,
wantOfficial
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeWxpay
)
.
SetName
(
"Official WeChat"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"wxpay"
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
1
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official wxpay instance: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
"EasyPay WeChat"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"wxpay"
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
2
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay wxpay instance: %v"
,
err
)
}
svc
:=
&
PaymentService
{
configService
:
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingPaymentVisibleMethodWxpaySource
:
tt
.
source
,
},
},
},
}
if
got
:=
svc
.
usesOfficialWxpayVisibleMethod
(
ctx
);
got
!=
tt
.
wantOfficial
{
t
.
Fatalf
(
"usesOfficialWxpayVisibleMethod() = %v, want %v"
,
got
,
tt
.
wantOfficial
)
}
})
}
}
backend/internal/service/payment_order_lifecycle.go
View file @
ddf80f5e
...
@@ -150,6 +150,20 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
...
@@ -150,6 +150,20 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return
""
return
""
}
}
if
resp
.
Status
==
payment
.
ProviderStatusPaid
{
if
resp
.
Status
==
payment
.
ProviderStatusPaid
{
if
!
isValidProviderAmount
(
resp
.
Amount
)
{
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
"PAYMENT_INVALID_AMOUNT"
,
prov
.
ProviderKey
(),
map
[
string
]
any
{
"expected"
:
o
.
PayAmount
,
"paid"
:
resp
.
Amount
,
"tradeNo"
:
resp
.
TradeNo
,
"queryRef"
:
queryRef
,
})
slog
.
Warn
(
"query upstream returned invalid paid amount"
,
"orderID"
,
o
.
ID
,
"queryRef"
,
queryRef
,
"paid"
,
resp
.
Amount
)
retriedResp
,
retryOK
:=
requeryPaidOrderOnce
(
ctx
,
prov
,
queryRef
)
if
!
retryOK
{
return
""
}
resp
=
retriedResp
}
notificationTradeNo
:=
o
.
PaymentTradeNo
notificationTradeNo
:=
o
.
PaymentTradeNo
if
upstreamTradeNo
:=
strings
.
TrimSpace
(
resp
.
TradeNo
);
paymentOrderShouldPersistUpstreamTradeNo
(
queryRef
,
upstreamTradeNo
,
notificationTradeNo
)
{
if
upstreamTradeNo
:=
strings
.
TrimSpace
(
resp
.
TradeNo
);
paymentOrderShouldPersistUpstreamTradeNo
(
queryRef
,
upstreamTradeNo
,
notificationTradeNo
)
{
if
_
,
updateErr
:=
s
.
entClient
.
PaymentOrder
.
Update
()
.
if
_
,
updateErr
:=
s
.
entClient
.
PaymentOrder
.
Update
()
.
...
@@ -174,6 +188,21 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
...
@@ -174,6 +188,21 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return
""
return
""
}
}
func
requeryPaidOrderOnce
(
ctx
context
.
Context
,
prov
payment
.
Provider
,
queryRef
string
)
(
*
payment
.
QueryOrderResponse
,
bool
)
{
if
prov
==
nil
||
strings
.
TrimSpace
(
queryRef
)
==
""
{
return
nil
,
false
}
resp
,
err
:=
prov
.
QueryOrder
(
ctx
,
queryRef
)
if
err
!=
nil
{
slog
.
Warn
(
"query upstream retry failed"
,
"queryRef"
,
queryRef
,
"error"
,
err
)
return
nil
,
false
}
if
resp
==
nil
||
resp
.
Status
!=
payment
.
ProviderStatusPaid
||
!
isValidProviderAmount
(
resp
.
Amount
)
{
return
nil
,
false
}
return
resp
,
true
}
func
paymentOrderQueryReference
(
order
*
dbent
.
PaymentOrder
,
prov
payment
.
Provider
)
string
{
func
paymentOrderQueryReference
(
order
*
dbent
.
PaymentOrder
,
prov
payment
.
Provider
)
string
{
if
order
==
nil
{
if
order
==
nil
{
return
""
return
""
...
@@ -224,6 +253,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
...
@@ -224,6 +253,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func
(
s
*
PaymentService
)
VerifyOrderByOutTradeNo
(
ctx
context
.
Context
,
outTradeNo
string
,
userID
int64
)
(
*
dbent
.
PaymentOrder
,
error
)
{
func
(
s
*
PaymentService
)
VerifyOrderByOutTradeNo
(
ctx
context
.
Context
,
outTradeNo
string
,
userID
int64
)
(
*
dbent
.
PaymentOrder
,
error
)
{
outTradeNo
,
err
:=
normalizeOrderLookupOutTradeNo
(
outTradeNo
)
if
err
!=
nil
{
return
nil
,
err
}
o
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
o
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
OutTradeNo
(
outTradeNo
))
.
Where
(
paymentorder
.
OutTradeNo
(
outTradeNo
))
.
Only
(
ctx
)
Only
(
ctx
)
...
@@ -251,6 +284,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
...
@@ -251,6 +284,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
// only public recovery path allowed to query upstream state.
func
(
s
*
PaymentService
)
VerifyOrderPublic
(
ctx
context
.
Context
,
outTradeNo
string
)
(
*
dbent
.
PaymentOrder
,
error
)
{
func
(
s
*
PaymentService
)
VerifyOrderPublic
(
ctx
context
.
Context
,
outTradeNo
string
)
(
*
dbent
.
PaymentOrder
,
error
)
{
outTradeNo
,
err
:=
normalizeOrderLookupOutTradeNo
(
outTradeNo
)
if
err
!=
nil
{
return
nil
,
err
}
o
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
o
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
OutTradeNo
(
outTradeNo
))
.
Where
(
paymentorder
.
OutTradeNo
(
outTradeNo
))
.
Only
(
ctx
)
Only
(
ctx
)
...
@@ -260,6 +297,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
...
@@ -260,6 +297,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return
o
,
nil
return
o
,
nil
}
}
func
normalizeOrderLookupOutTradeNo
(
raw
string
)
(
string
,
error
)
{
outTradeNo
:=
strings
.
TrimSpace
(
raw
)
if
outTradeNo
==
""
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_OUT_TRADE_NO"
,
"out_trade_no is required"
)
}
if
len
(
outTradeNo
)
>
64
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_OUT_TRADE_NO"
,
"out_trade_no is invalid"
)
}
for
_
,
ch
:=
range
outTradeNo
{
switch
{
case
ch
>=
'a'
&&
ch
<=
'z'
:
case
ch
>=
'A'
&&
ch
<=
'Z'
:
case
ch
>=
'0'
&&
ch
<=
'9'
:
case
ch
==
'_'
||
ch
==
'-'
:
default
:
return
""
,
infraerrors
.
BadRequest
(
"INVALID_OUT_TRADE_NO"
,
"out_trade_no is invalid"
)
}
}
return
outTradeNo
,
nil
}
func
(
s
*
PaymentService
)
ExpireTimedOutOrders
(
ctx
context
.
Context
)
(
int
,
error
)
{
func
(
s
*
PaymentService
)
ExpireTimedOutOrders
(
ctx
context
.
Context
)
(
int
,
error
)
{
now
:=
time
.
Now
()
now
:=
time
.
Now
()
orders
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
StatusEQ
(
OrderStatusPending
),
paymentorder
.
ExpiresAtLTE
(
now
))
.
All
(
ctx
)
orders
,
err
:=
s
.
entClient
.
PaymentOrder
.
Query
()
.
Where
(
paymentorder
.
StatusEQ
(
OrderStatusPending
),
paymentorder
.
ExpiresAtLTE
(
now
))
.
All
(
ctx
)
...
...
backend/internal/service/payment_order_lifecycle_test.go
View file @
ddf80f5e
...
@@ -21,6 +21,8 @@ import (
...
@@ -21,6 +21,8 @@ import (
type
paymentOrderLifecycleQueryProvider
struct
{
type
paymentOrderLifecycleQueryProvider
struct
{
lastQueryTradeNo
string
lastQueryTradeNo
string
queryCalls
int
responses
[]
*
payment
.
QueryOrderResponse
resp
*
payment
.
QueryOrderResponse
resp
*
payment
.
QueryOrderResponse
}
}
...
@@ -48,6 +50,14 @@ func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, paym
...
@@ -48,6 +50,14 @@ func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, paym
func
(
p
*
paymentOrderLifecycleQueryProvider
)
QueryOrder
(
_
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
func
(
p
*
paymentOrderLifecycleQueryProvider
)
QueryOrder
(
_
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
p
.
lastQueryTradeNo
=
tradeNo
p
.
lastQueryTradeNo
=
tradeNo
p
.
queryCalls
++
if
len
(
p
.
responses
)
>
0
{
resp
:=
p
.
responses
[
0
]
if
len
(
p
.
responses
)
>
1
{
p
.
responses
=
p
.
responses
[
1
:
]
}
return
resp
,
nil
}
return
p
.
resp
,
nil
return
p
.
resp
,
nil
}
}
...
@@ -234,6 +244,194 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
...
@@ -234,6 +244,194 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require
.
Equal
(
t
,
user
.
ID
,
redeemRepo
.
useCalls
[
0
]
.
userID
)
require
.
Equal
(
t
,
user
.
ID
,
redeemRepo
.
useCalls
[
0
]
.
userID
)
}
}
func
TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
newPaymentOrderLifecycleTestClient
(
t
)
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"checkpaid-retry@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"checkpaid-retry-user"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
88
)
.
SetPayAmount
(
88
)
.
SetFeeRate
(
0
)
.
SetRechargeCode
(
"CHECKPAID-UPSTREAM-RETRY"
)
.
SetOutTradeNo
(
"sub2_checkpaid_retry_zero_amount"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
""
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
OrderStatusPending
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
userRepo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
user
.
ID
,
Email
:
user
.
Email
,
Username
:
user
.
Username
,
Balance
:
0
,
},
}
userRepo
.
updateBalanceFn
=
func
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
require
.
Equal
(
t
,
user
.
ID
,
id
)
if
userRepo
.
getByIDUser
!=
nil
{
userRepo
.
getByIDUser
.
Balance
+=
amount
}
return
nil
}
redeemRepo
:=
&
paymentOrderLifecycleRedeemRepo
{
codesByCode
:
map
[
string
]
*
RedeemCode
{
order
.
RechargeCode
:
{
ID
:
1
,
Code
:
order
.
RechargeCode
,
Type
:
RedeemTypeBalance
,
Value
:
order
.
Amount
,
Status
:
StatusUnused
,
},
},
}
redeemService
:=
NewRedeemService
(
redeemRepo
,
userRepo
,
nil
,
nil
,
nil
,
client
,
nil
,
)
registry
:=
payment
.
NewRegistry
()
provider
:=
&
paymentOrderLifecycleQueryProvider
{
responses
:
[]
*
payment
.
QueryOrderResponse
{
{
TradeNo
:
"upstream-trade-zero"
,
Status
:
payment
.
ProviderStatusPaid
,
Amount
:
0
,
},
{
TradeNo
:
"upstream-trade-retry"
,
Status
:
payment
.
ProviderStatusPaid
,
Amount
:
88
,
},
},
}
registry
.
Register
(
provider
)
svc
:=
&
PaymentService
{
entClient
:
client
,
registry
:
registry
,
redeemService
:
redeemService
,
userRepo
:
userRepo
,
providersLoaded
:
true
,
}
got
,
err
:=
svc
.
VerifyOrderByOutTradeNo
(
ctx
,
order
.
OutTradeNo
,
user
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
2
,
provider
.
queryCalls
)
require
.
Equal
(
t
,
OrderStatusCompleted
,
got
.
Status
)
require
.
Equal
(
t
,
"upstream-trade-retry"
,
got
.
PaymentTradeNo
)
}
func
TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
client
:=
newPaymentOrderLifecycleTestClient
(
t
)
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"checkpaid-zero-amount@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"checkpaid-zero-amount-user"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
88
)
.
SetPayAmount
(
88
)
.
SetFeeRate
(
0
)
.
SetRechargeCode
(
"CHECKPAID-ZERO-AMOUNT"
)
.
SetOutTradeNo
(
"sub2_checkpaid_zero_amount"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
""
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
OrderStatusPending
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
userRepo
:=
&
mockUserRepo
{
getByIDUser
:
&
User
{
ID
:
user
.
ID
,
Email
:
user
.
Email
,
Username
:
user
.
Username
,
Balance
:
0
,
},
}
redeemRepo
:=
&
paymentOrderLifecycleRedeemRepo
{
codesByCode
:
map
[
string
]
*
RedeemCode
{
order
.
RechargeCode
:
{
ID
:
1
,
Code
:
order
.
RechargeCode
,
Type
:
RedeemTypeBalance
,
Value
:
order
.
Amount
,
Status
:
StatusUnused
,
},
},
}
redeemService
:=
NewRedeemService
(
redeemRepo
,
userRepo
,
nil
,
nil
,
nil
,
client
,
nil
,
)
registry
:=
payment
.
NewRegistry
()
provider
:=
&
paymentOrderLifecycleQueryProvider
{
resp
:
&
payment
.
QueryOrderResponse
{
TradeNo
:
"upstream-trade-zero"
,
Status
:
payment
.
ProviderStatusPaid
,
Amount
:
0
,
},
}
registry
.
Register
(
provider
)
svc
:=
&
PaymentService
{
entClient
:
client
,
registry
:
registry
,
redeemService
:
redeemService
,
userRepo
:
userRepo
,
providersLoaded
:
true
,
}
got
,
err
:=
svc
.
VerifyOrderByOutTradeNo
(
ctx
,
order
.
OutTradeNo
,
user
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
order
.
OutTradeNo
,
provider
.
lastQueryTradeNo
)
require
.
Equal
(
t
,
OrderStatusPending
,
got
.
Status
)
require
.
Empty
(
t
,
got
.
PaymentTradeNo
)
reloaded
,
err
:=
client
.
PaymentOrder
.
Get
(
ctx
,
order
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
OrderStatusPending
,
reloaded
.
Status
)
require
.
Empty
(
t
,
reloaded
.
PaymentTradeNo
)
require
.
Equal
(
t
,
0.0
,
userRepo
.
getByIDUser
.
Balance
)
require
.
Empty
(
t
,
redeemRepo
.
useCalls
)
}
func
TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay
(
t
*
testing
.
T
)
{
func
TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
client
:=
newPaymentOrderLifecycleTestClient
(
t
)
client
:=
newPaymentOrderLifecycleTestClient
(
t
)
...
...
backend/internal/service/payment_order_result_test.go
View file @
ddf80f5e
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"context"
"context"
"strings"
"testing"
"testing"
"time"
"time"
...
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
...
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
}
func
TestMaybeBuildWeChatOAuthRequiredResponse
(
t
*
testing
.
T
)
{
func
TestMaybeBuildWeChatOAuthRequiredResponse
(
t
*
testing
.
T
)
{
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"0123456789abcdef0123456789abcdef"
)
svc
:=
newWeChatPaymentOAuthTestService
(
map
[
string
]
string
{
svc
:=
newWeChatPaymentOAuthTestService
(
map
[
string
]
string
{
SettingKeyWeChatConnectEnabled
:
"true"
,
SettingKeyWeChatConnectEnabled
:
"true"
,
SettingKeyWeChatConnectAppID
:
"wx123456"
,
SettingKeyWeChatConnectAppID
:
"wx123456"
,
...
@@ -159,6 +162,83 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
...
@@ -159,6 +162,83 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
}
}
}
func
TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
svc
:=
&
PaymentService
{
configService
:
&
PaymentConfigService
{
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyWeChatConnectEnabled
:
"true"
,
SettingKeyWeChatConnectAppID
:
"wx123456"
,
SettingKeyWeChatConnectAppSecret
:
"wechat-secret"
,
SettingKeyWeChatConnectMode
:
"mp"
,
SettingKeyWeChatConnectScopes
:
"snsapi_base"
,
SettingKeyWeChatConnectRedirectURL
:
"https://api.example.com/api/v1/auth/oauth/wechat/callback"
,
SettingKeyWeChatConnectFrontendRedirectURL
:
"/auth/wechat/callback"
,
}},
// Intentionally missing payment resume signing key.
encryptionKey
:
nil
,
},
}
resp
,
err
:=
svc
.
maybeBuildWeChatOAuthRequiredResponse
(
context
.
Background
(),
CreateOrderRequest
{
Amount
:
12.5
,
PaymentType
:
payment
.
TypeWxpay
,
IsWeChatBrowser
:
true
,
SrcURL
:
"https://merchant.example/payment?from=wechat"
,
OrderType
:
payment
.
OrderTypeBalance
,
},
12.5
,
12.88
,
0.03
)
if
resp
!=
nil
{
t
.
Fatalf
(
"expected nil response, got %+v"
,
resp
)
}
if
err
==
nil
{
t
.
Fatal
(
"expected error, got nil"
)
}
appErr
:=
infraerrors
.
FromError
(
err
)
if
appErr
.
Reason
!=
"PAYMENT_RESUME_NOT_CONFIGURED"
{
t
.
Fatalf
(
"reason = %q, want %q"
,
appErr
.
Reason
,
"PAYMENT_RESUME_NOT_CONFIGURED"
)
}
}
func
TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey
(
t
*
testing
.
T
)
{
svc
:=
&
PaymentService
{
configService
:
&
PaymentConfigService
{
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyWeChatConnectEnabled
:
"true"
,
SettingKeyWeChatConnectAppID
:
"wx123456"
,
SettingKeyWeChatConnectAppSecret
:
"wechat-secret"
,
SettingKeyWeChatConnectMode
:
"mp"
,
SettingKeyWeChatConnectScopes
:
"snsapi_base"
,
SettingKeyWeChatConnectRedirectURL
:
"https://api.example.com/api/v1/auth/oauth/wechat/callback"
,
SettingKeyWeChatConnectFrontendRedirectURL
:
"/auth/wechat/callback"
,
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey
:
[]
byte
(
"0123456789abcdef0123456789abcdef"
),
},
}
resp
,
err
:=
svc
.
maybeBuildWeChatOAuthRequiredResponse
(
context
.
Background
(),
CreateOrderRequest
{
Amount
:
12.5
,
PaymentType
:
payment
.
TypeWxpay
,
IsWeChatBrowser
:
true
,
SrcURL
:
"https://merchant.example/payment?from=wechat"
,
OrderType
:
payment
.
OrderTypeBalance
,
},
12.5
,
12.88
,
0.03
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
resp
==
nil
{
t
.
Fatal
(
"expected oauth-required response, got nil"
)
}
if
resp
.
ResultType
!=
payment
.
CreatePaymentResultOAuthRequired
{
t
.
Fatalf
(
"result type = %q, want %q"
,
resp
.
ResultType
,
payment
.
CreatePaymentResultOAuthRequired
)
}
if
resp
.
OAuth
==
nil
||
strings
.
TrimSpace
(
resp
.
OAuth
.
AuthorizeURL
)
==
""
{
t
.
Fatalf
(
"expected oauth redirect payload, got %+v"
,
resp
.
OAuth
)
}
}
func
TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider
(
t
*
testing
.
T
)
{
func
TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider
(
t
*
testing
.
T
)
{
svc
:=
newWeChatPaymentOAuthTestService
(
map
[
string
]
string
{
svc
:=
newWeChatPaymentOAuthTestService
(
map
[
string
]
string
{
SettingKeyWeChatConnectEnabled
:
"true"
,
SettingKeyWeChatConnectEnabled
:
"true"
,
...
@@ -189,7 +269,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
...
@@ -189,7 +269,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
func
newWeChatPaymentOAuthTestService
(
values
map
[
string
]
string
)
*
PaymentService
{
func
newWeChatPaymentOAuthTestService
(
values
map
[
string
]
string
)
*
PaymentService
{
return
&
PaymentService
{
return
&
PaymentService
{
configService
:
&
PaymentConfigService
{
configService
:
&
PaymentConfigService
{
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
values
},
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
values
},
encryptionKey
:
[]
byte
(
"0123456789abcdef0123456789abcdef"
),
},
},
}
}
}
}
backend/internal/service/payment_resume_lookup.go
View file @
ddf80f5e
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"strings"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
)
func
(
s
*
PaymentService
)
GetPublicOrderByResumeToken
(
ctx
context
.
Context
,
token
string
)
(
*
dbent
.
PaymentOrder
,
error
)
{
func
(
s
*
PaymentService
)
GetPublicOrderByResumeToken
(
ctx
context
.
Context
,
token
string
)
(
*
dbent
.
PaymentOrder
,
error
)
{
...
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
...
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order
,
err
:=
s
.
entClient
.
PaymentOrder
.
Get
(
ctx
,
claims
.
OrderID
)
order
,
err
:=
s
.
entClient
.
PaymentOrder
.
Get
(
ctx
,
claims
.
OrderID
)
if
err
!=
nil
{
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
infraerrors
.
NotFound
(
"NOT_FOUND"
,
"order not found"
)
}
return
nil
,
fmt
.
Errorf
(
"get order by resume token: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get order by resume token: %w"
,
err
)
}
}
if
claims
.
UserID
>
0
&&
order
.
UserID
!=
claims
.
UserID
{
if
claims
.
UserID
>
0
&&
order
.
UserID
!=
claims
.
UserID
{
return
nil
,
fmt
.
Errorf
(
"r
esume
t
oken
user mismatch"
)
return
nil
,
invalidR
esume
T
oken
MatchError
(
)
}
}
snapshot
:=
psOrderProviderSnapshot
(
order
)
snapshot
:=
psOrderProviderSnapshot
(
order
)
orderProviderInstanceID
:=
strings
.
TrimSpace
(
psStringValue
(
order
.
ProviderInstanceID
))
orderProviderInstanceID
:=
strings
.
TrimSpace
(
psStringValue
(
order
.
ProviderInstanceID
))
...
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
...
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
}
}
}
}
if
claims
.
ProviderInstanceID
!=
""
&&
orderProviderInstanceID
!=
claims
.
ProviderInstanceID
{
if
claims
.
ProviderInstanceID
!=
""
&&
orderProviderInstanceID
!=
claims
.
ProviderInstanceID
{
return
nil
,
fmt
.
Errorf
(
"r
esume
t
oken
provider instance mismatch"
)
return
nil
,
invalidR
esume
T
oken
MatchError
(
)
}
}
if
claims
.
ProviderKey
!=
""
&&
orderProviderKey
!=
claims
.
ProviderKey
{
if
claims
.
ProviderKey
!=
""
&&
!
strings
.
EqualFold
(
orderProviderKey
,
claims
.
ProviderKey
)
{
return
nil
,
fmt
.
Errorf
(
"r
esume
t
oken
provider key mismatch"
)
return
nil
,
invalidR
esume
T
oken
MatchError
(
)
}
}
if
claims
.
PaymentType
!=
""
&&
strings
.
TrimSpace
(
order
.
PaymentType
)
!=
claims
.
PaymentType
{
if
claims
.
PaymentType
!=
""
&&
NormalizeVisibleMethod
(
order
.
PaymentType
)
!=
NormalizeVisibleMethod
(
claims
.
PaymentType
)
{
return
nil
,
fmt
.
Errorf
(
"r
esume
t
oken
payment type mismatch"
)
return
nil
,
invalidR
esume
T
oken
MatchError
(
)
}
}
if
order
.
Status
==
OrderStatusPending
||
order
.
Status
==
OrderStatusExpired
{
if
order
.
Status
==
OrderStatusPending
||
order
.
Status
==
OrderStatusExpired
{
result
:=
s
.
checkPaid
(
ctx
,
order
)
result
:=
s
.
checkPaid
(
ctx
,
order
)
...
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
...
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return
order
,
nil
return
order
,
nil
}
}
func
invalidResumeTokenMatchError
()
error
{
return
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token does not match the payment order"
)
}
func
(
s
*
PaymentService
)
ParseWeChatPaymentResumeToken
(
token
string
)
(
*
WeChatPaymentResumeClaims
,
error
)
{
func
(
s
*
PaymentService
)
ParseWeChatPaymentResumeToken
(
token
string
)
(
*
WeChatPaymentResumeClaims
,
error
)
{
return
s
.
paymentResume
()
.
ParseWeChatPaymentResumeToken
(
strings
.
TrimSpace
(
token
))
return
s
.
paymentResume
()
.
ParseWeChatPaymentResumeToken
(
strings
.
TrimSpace
(
token
))
}
}
backend/internal/service/payment_resume_lookup_test.go
View file @
ddf80f5e
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
...
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
...
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_
,
err
=
svc
.
GetPublicOrderByResumeToken
(
ctx
,
token
)
_
,
err
=
svc
.
GetPublicOrderByResumeToken
(
ctx
,
token
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"resume token"
)
require
.
Equal
(
t
,
"INVALID_RESUME_TOKEN"
,
infraerrors
.
Reason
(
err
)
)
}
}
func
TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer
(
t
*
testing
.
T
)
{
func
TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer
(
t
*
testing
.
T
)
{
...
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
...
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require
.
Equal
(
t
,
order
.
ID
,
got
.
ID
)
require
.
Equal
(
t
,
order
.
ID
,
got
.
ID
)
require
.
Equal
(
t
,
0
,
provider
.
queryCount
)
require
.
Equal
(
t
,
0
,
provider
.
queryCount
)
}
}
func
TestVerifyOrderPublicRejectsBlankOutTradeNo
(
t
*
testing
.
T
)
{
svc
:=
&
PaymentService
{
entClient
:
newPaymentConfigServiceTestClient
(
t
),
}
_
,
err
:=
svc
.
VerifyOrderPublic
(
context
.
Background
(),
" "
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"INVALID_OUT_TRADE_NO"
,
infraerrors
.
Reason
(
err
))
}
backend/internal/service/payment_resume_service.go
View file @
ddf80f5e
package
service
package
service
import
(
import
(
"bytes"
"context"
"context"
"crypto/hmac"
"crypto/hmac"
"crypto/sha256"
"crypto/sha256"
...
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
...
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type
PaymentResumeService
struct
{
type
PaymentResumeService
struct
{
signingKey
[]
byte
signingKey
[]
byte
verifyKeys
[][]
byte
}
}
type
visibleMethodLoadBalancer
struct
{
type
visibleMethodLoadBalancer
struct
{
...
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
...
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService
*
PaymentConfigService
configService
*
PaymentConfigService
}
}
func
NewPaymentResumeService
(
signingKey
[]
byte
)
*
PaymentResumeService
{
func
NewPaymentResumeService
(
signingKey
[]
byte
,
verifyFallbacks
...
[]
byte
)
*
PaymentResumeService
{
return
&
PaymentResumeService
{
signingKey
:
signingKey
}
svc
:=
&
PaymentResumeService
{}
if
len
(
signingKey
)
>
0
{
svc
.
signingKey
=
append
([]
byte
(
nil
),
signingKey
...
)
svc
.
verifyKeys
=
append
(
svc
.
verifyKeys
,
svc
.
signingKey
)
}
for
_
,
fallback
:=
range
verifyFallbacks
{
if
len
(
fallback
)
==
0
{
continue
}
cloned
:=
append
([]
byte
(
nil
),
fallback
...
)
duplicate
:=
false
for
_
,
existing
:=
range
svc
.
verifyKeys
{
if
bytes
.
Equal
(
existing
,
cloned
)
{
duplicate
=
true
break
}
}
if
!
duplicate
{
svc
.
verifyKeys
=
append
(
svc
.
verifyKeys
,
cloned
)
}
}
return
svc
}
}
func
(
s
*
PaymentResumeService
)
isSigningConfigured
()
bool
{
func
(
s
*
PaymentResumeService
)
isSigningConfigured
()
bool
{
...
@@ -209,7 +232,7 @@ func visibleMethodSourceSettingKey(method string) string {
...
@@ -209,7 +232,7 @@ func visibleMethodSourceSettingKey(method string) string {
}
}
}
}
func
CanonicalizeReturnURL
(
raw
string
,
srcHost
string
)
(
string
,
error
)
{
func
CanonicalizeReturnURL
(
raw
string
,
srcHost
string
,
srcURL
string
)
(
string
,
error
)
{
raw
=
strings
.
TrimSpace
(
raw
)
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
if
raw
==
""
{
return
""
,
nil
return
""
,
nil
...
@@ -228,13 +251,29 @@ func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
...
@@ -228,13 +251,29 @@ func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
if
parsed
.
Path
!=
paymentResultReturnPath
{
if
parsed
.
Path
!=
paymentResultReturnPath
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must target the canonical internal payment result page"
)
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must target the canonical internal payment result page"
)
}
}
if
!
sameOrigin
Host
(
parsed
.
Host
,
srcHost
)
{
if
!
allowedReturnURL
Host
(
parsed
.
Host
,
srcHost
,
srcURL
)
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must use the same host as the current site"
)
return
""
,
infraerrors
.
BadRequest
(
"INVALID_RETURN_URL"
,
"return_url must use the same host as the current site
or browser origin
"
)
}
}
return
parsed
.
String
(),
nil
return
parsed
.
String
(),
nil
}
}
func
buildPaymentReturnURL
(
base
string
,
orderID
int64
,
resumeToken
string
)
(
string
,
error
)
{
func
allowedReturnURLHost
(
returnURLHost
string
,
requestHost
string
,
refererURL
string
)
bool
{
if
sameOriginHost
(
returnURLHost
,
requestHost
)
{
return
true
}
refererURL
=
strings
.
TrimSpace
(
refererURL
)
if
refererURL
==
""
{
return
false
}
parsedReferer
,
err
:=
url
.
Parse
(
refererURL
)
if
err
!=
nil
||
parsedReferer
.
Host
==
""
{
return
false
}
return
sameOriginHost
(
returnURLHost
,
parsedReferer
.
Host
)
}
func
buildPaymentReturnURL
(
base
string
,
orderID
int64
,
outTradeNo
string
,
resumeToken
string
)
(
string
,
error
)
{
canonical
:=
strings
.
TrimSpace
(
base
)
canonical
:=
strings
.
TrimSpace
(
base
)
if
canonical
==
""
{
if
canonical
==
""
{
return
""
,
nil
return
""
,
nil
...
@@ -253,6 +292,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
...
@@ -253,6 +292,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
if
orderID
>
0
{
if
orderID
>
0
{
query
.
Set
(
"order_id"
,
strconv
.
FormatInt
(
orderID
,
10
))
query
.
Set
(
"order_id"
,
strconv
.
FormatInt
(
orderID
,
10
))
}
}
if
strings
.
TrimSpace
(
outTradeNo
)
!=
""
{
query
.
Set
(
"out_trade_no"
,
strings
.
TrimSpace
(
outTradeNo
))
}
if
strings
.
TrimSpace
(
resumeToken
)
!=
""
{
if
strings
.
TrimSpace
(
resumeToken
)
!=
""
{
query
.
Set
(
"resume_token"
,
strings
.
TrimSpace
(
resumeToken
))
query
.
Set
(
"resume_token"
,
strings
.
TrimSpace
(
resumeToken
))
}
}
...
@@ -391,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
...
@@ -391,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if
len
(
parts
)
!=
2
||
parts
[
0
]
==
""
||
parts
[
1
]
==
""
{
if
len
(
parts
)
!=
2
||
parts
[
0
]
==
""
||
parts
[
1
]
==
""
{
return
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token is malformed"
)
return
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token is malformed"
)
}
}
if
!
hmac
.
Equal
([]
byte
(
parts
[
1
]),
[]
byte
(
s
.
sign
(
parts
[
0
]))
)
{
if
!
s
.
verifySignature
(
parts
[
0
],
parts
[
1
]
)
{
return
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token signature mismatch"
)
return
infraerrors
.
BadRequest
(
"INVALID_RESUME_TOKEN"
,
"resume token signature mismatch"
)
}
}
payload
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
parts
[
0
])
payload
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
parts
[
0
])
...
@@ -401,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
...
@@ -401,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return
json
.
Unmarshal
(
payload
,
dest
)
return
json
.
Unmarshal
(
payload
,
dest
)
}
}
func
(
s
*
PaymentResumeService
)
verifySignature
(
payload
string
,
signature
string
)
bool
{
if
s
==
nil
{
return
false
}
for
_
,
key
:=
range
s
.
verifyKeys
{
if
hmac
.
Equal
([]
byte
(
signature
),
[]
byte
(
signPaymentResumePayload
(
payload
,
key
)))
{
return
true
}
}
return
false
}
func
validatePaymentResumeExpiry
(
expiresAt
int64
,
code
,
message
string
)
error
{
func
validatePaymentResumeExpiry
(
expiresAt
int64
,
code
,
message
string
)
error
{
if
expiresAt
<=
0
{
if
expiresAt
<=
0
{
return
nil
return
nil
...
@@ -412,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
...
@@ -412,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
}
}
func
(
s
*
PaymentResumeService
)
sign
(
payload
string
)
string
{
func
(
s
*
PaymentResumeService
)
sign
(
payload
string
)
string
{
mac
:=
hmac
.
New
(
sha256
.
New
,
s
.
signingKey
)
return
signPaymentResumePayload
(
payload
,
s
.
signingKey
)
}
func
signPaymentResumePayload
(
payload
string
,
key
[]
byte
)
string
{
mac
:=
hmac
.
New
(
sha256
.
New
,
key
)
_
,
_
=
mac
.
Write
([]
byte
(
payload
))
_
,
_
=
mac
.
Write
([]
byte
(
payload
))
return
base64
.
RawURLEncoding
.
EncodeToString
(
mac
.
Sum
(
nil
))
return
base64
.
RawURLEncoding
.
EncodeToString
(
mac
.
Sum
(
nil
))
}
}
backend/internal/service/payment_resume_service_test.go
View file @
ddf80f5e
...
@@ -14,6 +14,7 @@ import (
...
@@ -14,6 +14,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
)
func
TestNormalizeVisibleMethods
(
t
*
testing
.
T
)
{
func
TestNormalizeVisibleMethods
(
t
*
testing
.
T
)
{
...
@@ -64,7 +65,7 @@ func TestNormalizePaymentSource(t *testing.T) {
...
@@ -64,7 +65,7 @@ func TestNormalizePaymentSource(t *testing.T) {
func
TestCanonicalizeReturnURL
(
t
*
testing
.
T
)
{
func
TestCanonicalizeReturnURL
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
got
,
err
:=
CanonicalizeReturnURL
(
"https://example.com/payment/result?b=2#a"
,
"example.com"
)
got
,
err
:=
CanonicalizeReturnURL
(
"https://example.com/payment/result?b=2#a"
,
"example.com"
,
""
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"CanonicalizeReturnURL returned error: %v"
,
err
)
t
.
Fatalf
(
"CanonicalizeReturnURL returned error: %v"
,
err
)
}
}
...
@@ -76,7 +77,7 @@ func TestCanonicalizeReturnURL(t *testing.T) {
...
@@ -76,7 +77,7 @@ func TestCanonicalizeReturnURL(t *testing.T) {
func
TestCanonicalizeReturnURLRejectsRelativeURL
(
t
*
testing
.
T
)
{
func
TestCanonicalizeReturnURLRejectsRelativeURL
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
if
_
,
err
:=
CanonicalizeReturnURL
(
"/payment/result"
,
"example.com"
);
err
==
nil
{
if
_
,
err
:=
CanonicalizeReturnURL
(
"/payment/result"
,
"example.com"
,
""
);
err
==
nil
{
t
.
Fatal
(
"CanonicalizeReturnURL should reject relative URLs"
)
t
.
Fatal
(
"CanonicalizeReturnURL should reject relative URLs"
)
}
}
}
}
...
@@ -84,15 +85,31 @@ func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
...
@@ -84,15 +85,31 @@ func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
func
TestCanonicalizeReturnURLRejectsExternalHost
(
t
*
testing
.
T
)
{
func
TestCanonicalizeReturnURLRejectsExternalHost
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
if
_
,
err
:=
CanonicalizeReturnURL
(
"https://evil.example/payment/result"
,
"app.example.com"
);
err
==
nil
{
if
_
,
err
:=
CanonicalizeReturnURL
(
"https://evil.example/payment/result"
,
"app.example.com"
,
""
);
err
==
nil
{
t
.
Fatal
(
"CanonicalizeReturnURL should reject external hosts"
)
t
.
Fatal
(
"CanonicalizeReturnURL should reject external hosts"
)
}
}
}
}
func
TestCanonicalizeReturnURLAllowsConfiguredFrontendHost
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
CanonicalizeReturnURL
(
"https://app.example.com/payment/result?from=checkout"
,
"api.example.com"
,
"https://app.example.com/purchase"
,
)
if
err
!=
nil
{
t
.
Fatalf
(
"CanonicalizeReturnURL returned error: %v"
,
err
)
}
if
got
!=
"https://app.example.com/payment/result?from=checkout"
{
t
.
Fatalf
(
"CanonicalizeReturnURL = %q, want %q"
,
got
,
"https://app.example.com/payment/result?from=checkout"
)
}
}
func
TestCanonicalizeReturnURLRejectsNonCanonicalPath
(
t
*
testing
.
T
)
{
func
TestCanonicalizeReturnURLRejectsNonCanonicalPath
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
if
_
,
err
:=
CanonicalizeReturnURL
(
"https://app.example.com/orders/42"
,
"app.example.com"
);
err
==
nil
{
if
_
,
err
:=
CanonicalizeReturnURL
(
"https://app.example.com/orders/42"
,
"app.example.com"
,
""
);
err
==
nil
{
t
.
Fatal
(
"CanonicalizeReturnURL should reject non-canonical result paths"
)
t
.
Fatal
(
"CanonicalizeReturnURL should reject non-canonical result paths"
)
}
}
}
}
...
@@ -100,7 +117,7 @@ func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
...
@@ -100,7 +117,7 @@ func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
func
TestBuildPaymentReturnURL
(
t
*
testing
.
T
)
{
func
TestBuildPaymentReturnURL
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
got
,
err
:=
buildPaymentReturnURL
(
"https://example.com/payment/result?from=checkout#fragment"
,
42
,
"resume-token"
)
got
,
err
:=
buildPaymentReturnURL
(
"https://example.com/payment/result?from=checkout#fragment"
,
42
,
"sub2_42"
,
"resume-token"
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"buildPaymentReturnURL returned error: %v"
,
err
)
t
.
Fatalf
(
"buildPaymentReturnURL returned error: %v"
,
err
)
}
}
...
@@ -119,6 +136,9 @@ func TestBuildPaymentReturnURL(t *testing.T) {
...
@@ -119,6 +136,9 @@ func TestBuildPaymentReturnURL(t *testing.T) {
if
query
.
Get
(
"order_id"
)
!=
strconv
.
FormatInt
(
42
,
10
)
{
if
query
.
Get
(
"order_id"
)
!=
strconv
.
FormatInt
(
42
,
10
)
{
t
.
Fatalf
(
"order_id = %q"
,
query
.
Get
(
"order_id"
))
t
.
Fatalf
(
"order_id = %q"
,
query
.
Get
(
"order_id"
))
}
}
if
query
.
Get
(
"out_trade_no"
)
!=
"sub2_42"
{
t
.
Fatalf
(
"out_trade_no = %q"
,
query
.
Get
(
"out_trade_no"
))
}
if
query
.
Get
(
"resume_token"
)
!=
"resume-token"
{
if
query
.
Get
(
"resume_token"
)
!=
"resume-token"
{
t
.
Fatalf
(
"resume_token = %q"
,
query
.
Get
(
"resume_token"
))
t
.
Fatalf
(
"resume_token = %q"
,
query
.
Get
(
"resume_token"
))
}
}
...
@@ -127,10 +147,34 @@ func TestBuildPaymentReturnURL(t *testing.T) {
...
@@ -127,10 +147,34 @@ func TestBuildPaymentReturnURL(t *testing.T) {
}
}
}
}
func
TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
buildPaymentReturnURL
(
"https://example.com/payment/result"
,
42
,
"sub2_42"
,
""
)
if
err
!=
nil
{
t
.
Fatalf
(
"buildPaymentReturnURL returned error: %v"
,
err
)
}
parsed
,
err
:=
url
.
Parse
(
got
)
if
err
!=
nil
{
t
.
Fatalf
(
"url.Parse returned error: %v"
,
err
)
}
query
:=
parsed
.
Query
()
if
query
.
Get
(
"order_id"
)
!=
"42"
{
t
.
Fatalf
(
"order_id = %q"
,
query
.
Get
(
"order_id"
))
}
if
query
.
Get
(
"out_trade_no"
)
!=
"sub2_42"
{
t
.
Fatalf
(
"out_trade_no = %q"
,
query
.
Get
(
"out_trade_no"
))
}
if
query
.
Get
(
"resume_token"
)
!=
""
{
t
.
Fatalf
(
"resume_token = %q, want empty"
,
query
.
Get
(
"resume_token"
))
}
}
func
TestBuildPaymentReturnURLEmptyBase
(
t
*
testing
.
T
)
{
func
TestBuildPaymentReturnURLEmptyBase
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
got
,
err
:=
buildPaymentReturnURL
(
""
,
42
,
"resume-token"
)
got
,
err
:=
buildPaymentReturnURL
(
""
,
42
,
"sub2_42"
,
"resume-token"
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"buildPaymentReturnURL returned error: %v"
,
err
)
t
.
Fatalf
(
"buildPaymentReturnURL returned error: %v"
,
err
)
}
}
...
@@ -290,6 +334,98 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
...
@@ -290,6 +334,98 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
}
}
}
}
func
TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey
(
t
*
testing
.
T
)
{
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"explicit-payment-resume-signing-key"
)
token
,
err
:=
NewPaymentResumeService
([]
byte
(
"explicit-payment-resume-signing-key"
))
.
CreateWeChatPaymentResumeToken
(
WeChatPaymentResumeClaims
{
OpenID
:
"openid-explicit-key"
,
PaymentType
:
payment
.
TypeWxpay
,
})
if
err
!=
nil
{
t
.
Fatalf
(
"CreateWeChatPaymentResumeToken returned error: %v"
,
err
)
}
svc
:=
&
PaymentService
{
configService
:
&
PaymentConfigService
{
encryptionKey
:
[]
byte
(
"0123456789abcdef0123456789abcdef"
),
},
}
claims
,
err
:=
svc
.
ParseWeChatPaymentResumeToken
(
token
)
if
err
!=
nil
{
t
.
Fatalf
(
"ParseWeChatPaymentResumeToken returned error: %v"
,
err
)
}
if
claims
.
OpenID
!=
"openid-explicit-key"
{
t
.
Fatalf
(
"openid = %q, want %q"
,
claims
.
OpenID
,
"openid-explicit-key"
)
}
}
func
TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration
(
t
*
testing
.
T
)
{
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"explicit-payment-resume-signing-key"
)
legacyKey
:=
[]
byte
(
"0123456789abcdef0123456789abcdef"
)
token
,
err
:=
NewPaymentResumeService
(
legacyKey
)
.
CreateWeChatPaymentResumeToken
(
WeChatPaymentResumeClaims
{
OpenID
:
"openid-legacy-key"
,
PaymentType
:
payment
.
TypeWxpay
,
})
if
err
!=
nil
{
t
.
Fatalf
(
"CreateWeChatPaymentResumeToken returned error: %v"
,
err
)
}
svc
:=
&
PaymentService
{
configService
:
&
PaymentConfigService
{
encryptionKey
:
legacyKey
,
},
}
claims
,
err
:=
svc
.
ParseWeChatPaymentResumeToken
(
token
)
if
err
!=
nil
{
t
.
Fatalf
(
"ParseWeChatPaymentResumeToken returned error: %v"
,
err
)
}
if
claims
.
OpenID
!=
"openid-legacy-key"
{
t
.
Fatalf
(
"openid = %q, want %q"
,
claims
.
OpenID
,
"openid-legacy-key"
)
}
}
func
TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback
(
t
*
testing
.
T
)
{
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"explicit-payment-resume-signing-key"
)
legacyKey
:=
[]
byte
(
"0123456789abcdef0123456789abcdef"
)
svc
:=
newLegacyAwarePaymentResumeService
(
legacyKey
)
explicitToken
,
err
:=
svc
.
CreateWeChatPaymentResumeToken
(
WeChatPaymentResumeClaims
{
OpenID
:
"openid-explicit-key"
,
PaymentType
:
payment
.
TypeWxpay
,
})
if
err
!=
nil
{
t
.
Fatalf
(
"CreateWeChatPaymentResumeToken returned error: %v"
,
err
)
}
explicitClaims
,
err
:=
NewPaymentResumeService
([]
byte
(
"explicit-payment-resume-signing-key"
))
.
ParseWeChatPaymentResumeToken
(
explicitToken
)
if
err
!=
nil
{
t
.
Fatalf
(
"ParseWeChatPaymentResumeToken returned error: %v"
,
err
)
}
if
explicitClaims
.
OpenID
!=
"openid-explicit-key"
{
t
.
Fatalf
(
"openid = %q, want %q"
,
explicitClaims
.
OpenID
,
"openid-explicit-key"
)
}
legacyToken
,
err
:=
NewPaymentResumeService
(
legacyKey
)
.
CreateWeChatPaymentResumeToken
(
WeChatPaymentResumeClaims
{
OpenID
:
"openid-legacy-key"
,
PaymentType
:
payment
.
TypeWxpay
,
})
if
err
!=
nil
{
t
.
Fatalf
(
"CreateWeChatPaymentResumeToken returned error: %v"
,
err
)
}
legacyClaims
,
err
:=
svc
.
ParseWeChatPaymentResumeToken
(
legacyToken
)
if
err
!=
nil
{
t
.
Fatalf
(
"ParseWeChatPaymentResumeToken returned error: %v"
,
err
)
}
if
legacyClaims
.
OpenID
!=
"openid-legacy-key"
{
t
.
Fatalf
(
"openid = %q, want %q"
,
legacyClaims
.
OpenID
,
"openid-legacy-key"
)
}
}
func
TestNormalizeVisibleMethodSource
(
t
*
testing
.
T
)
{
func
TestNormalizeVisibleMethodSource
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
...
@@ -376,6 +512,258 @@ func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
...
@@ -376,6 +512,258 @@ func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
}
}
}
}
func
TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
method
payment
.
PaymentType
officialName
string
officialTypes
string
easyPayName
string
easyPayTypes
string
sourceSetting
string
wantProvider
string
}{
{
name
:
"alipay uses official source"
,
method
:
payment
.
TypeAlipay
,
officialName
:
"Official Alipay"
,
officialTypes
:
"alipay"
,
easyPayName
:
"EasyPay Alipay"
,
easyPayTypes
:
"alipay"
,
sourceSetting
:
VisibleMethodSourceOfficialAlipay
,
wantProvider
:
payment
.
TypeAlipay
,
},
{
name
:
"alipay uses easypay source"
,
method
:
payment
.
TypeAlipay
,
officialName
:
"Official Alipay"
,
officialTypes
:
"alipay"
,
easyPayName
:
"EasyPay Alipay"
,
easyPayTypes
:
"alipay"
,
sourceSetting
:
VisibleMethodSourceEasyPayAlipay
,
wantProvider
:
payment
.
TypeEasyPay
,
},
{
name
:
"wxpay uses official source"
,
method
:
payment
.
TypeWxpay
,
officialName
:
"Official WeChat"
,
officialTypes
:
"wxpay"
,
easyPayName
:
"EasyPay WeChat"
,
easyPayTypes
:
"wxpay"
,
sourceSetting
:
VisibleMethodSourceOfficialWechat
,
wantProvider
:
payment
.
TypeWxpay
,
},
{
name
:
"wxpay uses easypay source"
,
method
:
payment
.
TypeWxpay
,
officialName
:
"Official WeChat"
,
officialTypes
:
"wxpay"
,
easyPayName
:
"EasyPay WeChat"
,
easyPayTypes
:
"wxpay"
,
sourceSetting
:
VisibleMethodSourceEasyPayWechat
,
wantProvider
:
payment
.
TypeEasyPay
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
officialProviderKey
:=
payment
.
TypeAlipay
if
tt
.
method
==
payment
.
TypeWxpay
{
officialProviderKey
=
payment
.
TypeWxpay
}
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
officialProviderKey
)
.
SetName
(
tt
.
officialName
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
tt
.
officialTypes
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
1
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official provider: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
tt
.
easyPayName
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
tt
.
easyPayTypes
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
2
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay provider: %v"
,
err
)
}
inner
:=
&
captureLoadBalancer
{}
configService
:=
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
visibleMethodSourceSettingKey
(
tt
.
method
)
:
tt
.
sourceSetting
,
},
},
}
lb
:=
newVisibleMethodLoadBalancer
(
inner
,
configService
)
_
,
err
=
lb
.
SelectInstance
(
ctx
,
""
,
tt
.
method
,
payment
.
StrategyRoundRobin
,
12.5
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectInstance returned error: %v"
,
err
)
}
if
inner
.
lastProviderKey
!=
tt
.
wantProvider
{
t
.
Fatalf
(
"lastProviderKey = %q, want %q"
,
inner
.
lastProviderKey
,
tt
.
wantProvider
)
}
})
}
}
func
TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing
(
t
*
testing
.
T
)
{
t
.
Parallel
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeAlipay
)
.
SetName
(
"Official Alipay"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
1
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official provider: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
"EasyPay Alipay"
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
"alipay"
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
2
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay provider: %v"
,
err
)
}
inner
:=
&
captureLoadBalancer
{}
configService
:=
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
visibleMethodSourceSettingKey
(
payment
.
TypeAlipay
)
:
""
,
},
},
}
lb
:=
newVisibleMethodLoadBalancer
(
inner
,
configService
)
_
,
err
=
lb
.
SelectInstance
(
ctx
,
""
,
payment
.
TypeAlipay
,
payment
.
StrategyRoundRobin
,
9.9
)
if
err
!=
nil
{
t
.
Fatalf
(
"SelectInstance returned error: %v"
,
err
)
}
if
inner
.
lastProviderKey
!=
""
{
t
.
Fatalf
(
"lastProviderKey = %q, want legacy cross-provider empty key"
,
inner
.
lastProviderKey
)
}
if
inner
.
lastPaymentType
!=
payment
.
TypeAlipay
{
t
.
Fatalf
(
"lastPaymentType = %q, want %q"
,
inner
.
lastPaymentType
,
payment
.
TypeAlipay
)
}
}
func
TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
method
payment
.
PaymentType
sourceValue
string
wantMessage
string
}{
{
name
:
"invalid wxpay source"
,
method
:
payment
.
TypeWxpay
,
sourceValue
:
"stripe"
,
wantMessage
:
"wxpay source must be one of the supported payment providers"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
ctx
:=
context
.
Background
()
client
:=
newPaymentConfigServiceTestClient
(
t
)
officialProviderKey
:=
payment
.
TypeAlipay
officialSupportedTypes
:=
"alipay"
officialName
:=
"Official Alipay"
easyPaySupportedTypes
:=
"alipay"
easyPayName
:=
"EasyPay Alipay"
if
tt
.
method
==
payment
.
TypeWxpay
{
officialProviderKey
=
payment
.
TypeWxpay
officialSupportedTypes
=
"wxpay"
officialName
=
"Official WeChat"
easyPaySupportedTypes
=
"wxpay"
easyPayName
=
"EasyPay WeChat"
}
_
,
err
:=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
officialProviderKey
)
.
SetName
(
officialName
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
officialSupportedTypes
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
1
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create official provider: %v"
,
err
)
}
_
,
err
=
client
.
PaymentProviderInstance
.
Create
()
.
SetProviderKey
(
payment
.
TypeEasyPay
)
.
SetName
(
easyPayName
)
.
SetConfig
(
"{}"
)
.
SetSupportedTypes
(
easyPaySupportedTypes
)
.
SetEnabled
(
true
)
.
SetSortOrder
(
2
)
.
Save
(
ctx
)
if
err
!=
nil
{
t
.
Fatalf
(
"create easypay provider: %v"
,
err
)
}
inner
:=
&
captureLoadBalancer
{}
configService
:=
&
PaymentConfigService
{
entClient
:
client
,
settingRepo
:
&
paymentConfigSettingRepoStub
{
values
:
map
[
string
]
string
{
visibleMethodSourceSettingKey
(
tt
.
method
)
:
tt
.
sourceValue
,
},
},
}
lb
:=
newVisibleMethodLoadBalancer
(
inner
,
configService
)
_
,
err
=
lb
.
SelectInstance
(
ctx
,
""
,
tt
.
method
,
payment
.
StrategyRoundRobin
,
9.9
)
if
err
==
nil
{
t
.
Fatal
(
"SelectInstance should reject invalid visible method source configuration"
)
}
if
infraerrors
.
Reason
(
err
)
!=
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE"
{
t
.
Fatalf
(
"Reason(err) = %q, want %q"
,
infraerrors
.
Reason
(
err
),
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE"
)
}
if
infraerrors
.
Message
(
err
)
!=
tt
.
wantMessage
{
t
.
Fatalf
(
"Message(err) = %q, want %q"
,
infraerrors
.
Message
(
err
),
tt
.
wantMessage
)
}
})
}
}
func
TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider
(
t
*
testing
.
T
)
{
func
TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
...
...
backend/internal/service/payment_service.go
View file @
ddf80f5e
package
service
package
service
import
(
import
(
"bytes"
"context"
"context"
"encoding/hex"
"fmt"
"fmt"
"log/slog"
"log/slog"
"math/rand/v2"
"math/rand/v2"
"os"
"strings"
"sync"
"sync"
"time"
"time"
...
@@ -44,6 +48,8 @@ const (
...
@@ -44,6 +48,8 @@ const (
orderIDPrefix
=
"sub2_"
orderIDPrefix
=
"sub2_"
)
)
const
paymentResumeSigningKeyEnv
=
"PAYMENT_RESUME_SIGNING_KEY"
// --- Types ---
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
// generateOutTradeNo creates a unique external order ID for payment providers.
...
@@ -179,7 +185,7 @@ type PaymentService struct {
...
@@ -179,7 +185,7 @@ type PaymentService struct {
func
NewPaymentService
(
entClient
*
dbent
.
Client
,
registry
*
payment
.
Registry
,
loadBalancer
payment
.
LoadBalancer
,
redeemService
*
RedeemService
,
subscriptionSvc
*
SubscriptionService
,
configService
*
PaymentConfigService
,
userRepo
UserRepository
,
groupRepo
GroupRepository
)
*
PaymentService
{
func
NewPaymentService
(
entClient
*
dbent
.
Client
,
registry
*
payment
.
Registry
,
loadBalancer
payment
.
LoadBalancer
,
redeemService
*
RedeemService
,
subscriptionSvc
*
SubscriptionService
,
configService
*
PaymentConfigService
,
userRepo
UserRepository
,
groupRepo
GroupRepository
)
*
PaymentService
{
svc
:=
&
PaymentService
{
entClient
:
entClient
,
registry
:
registry
,
loadBalancer
:
newVisibleMethodLoadBalancer
(
loadBalancer
,
configService
),
redeemService
:
redeemService
,
subscriptionSvc
:
subscriptionSvc
,
configService
:
configService
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
}
svc
:=
&
PaymentService
{
entClient
:
entClient
,
registry
:
registry
,
loadBalancer
:
newVisibleMethodLoadBalancer
(
loadBalancer
,
configService
),
redeemService
:
redeemService
,
subscriptionSvc
:
subscriptionSvc
,
configService
:
configService
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
}
svc
.
resumeService
=
NewPaymentResumeService
(
psResumeSigningKey
(
configService
)
)
svc
.
resumeService
=
ps
NewPaymentResumeService
(
configService
)
return
svc
return
svc
}
}
...
@@ -259,16 +265,56 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
...
@@ -259,16 +265,56 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if
s
.
resumeService
!=
nil
{
if
s
.
resumeService
!=
nil
{
return
s
.
resumeService
return
s
.
resumeService
}
}
return
NewPaymentResumeService
(
psResumeSigningKey
(
s
.
configService
))
return
psNewPaymentResumeService
(
s
.
configService
)
}
func
NewLegacyAwarePaymentResumeService
(
legacyKey
[]
byte
)
*
PaymentResumeService
{
return
newLegacyAwarePaymentResumeService
(
legacyKey
)
}
func
psNewPaymentResumeService
(
configService
*
PaymentConfigService
)
*
PaymentResumeService
{
return
newLegacyAwarePaymentResumeService
(
psResumeLegacyVerificationKey
(
configService
))
}
}
func
psResumeSigningKey
(
configService
*
PaymentConfigService
)
[]
byte
{
func
newLegacyAwarePaymentResumeService
(
legacyKey
[]
byte
)
*
PaymentResumeService
{
signingKey
,
verifyFallbacks
:=
resolvePaymentResumeSigningKeys
(
legacyKey
)
return
NewPaymentResumeService
(
signingKey
,
verifyFallbacks
...
)
}
func
psResumeLegacyVerificationKey
(
configService
*
PaymentConfigService
)
[]
byte
{
if
configService
==
nil
{
if
configService
==
nil
{
return
nil
return
nil
}
}
return
configService
.
encryptionKey
return
configService
.
encryptionKey
}
}
func
resolvePaymentResumeSigningKeys
(
legacyKey
[]
byte
)
([]
byte
,
[][]
byte
)
{
signingKey
:=
parsePaymentResumeSigningKey
(
os
.
Getenv
(
paymentResumeSigningKeyEnv
))
if
len
(
signingKey
)
==
0
{
if
len
(
legacyKey
)
==
0
{
return
nil
,
nil
}
return
legacyKey
,
nil
}
if
len
(
legacyKey
)
==
0
||
bytes
.
Equal
(
legacyKey
,
signingKey
)
{
return
signingKey
,
nil
}
return
signingKey
,
[][]
byte
{
legacyKey
}
}
func
parsePaymentResumeSigningKey
(
raw
string
)
[]
byte
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
nil
}
if
len
(
raw
)
>=
64
&&
len
(
raw
)
%
2
==
0
{
if
decoded
,
err
:=
hex
.
DecodeString
(
raw
);
err
==
nil
&&
len
(
decoded
)
>
0
{
return
decoded
}
}
return
[]
byte
(
raw
)
}
func
psSliceContains
(
sl
[]
string
,
s
string
)
bool
{
func
psSliceContains
(
sl
[]
string
,
s
string
)
bool
{
for
_
,
v
:=
range
sl
{
for
_
,
v
:=
range
sl
{
if
v
==
s
{
if
v
==
s
{
...
...
backend/internal/service/payment_visible_method_instances.go
View file @
ddf80f5e
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"context"
"context"
"errors"
"fmt"
"fmt"
"strings"
"strings"
...
@@ -82,19 +83,52 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
...
@@ -82,19 +83,52 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return
filtered
return
filtered
}
}
func
buildPaymentProviderConflictError
(
method
string
,
conflicting
*
dbent
.
PaymentProviderInstance
)
error
{
func
filterVisibleMethodInstancesByProviderKey
(
instances
[]
*
dbent
.
PaymentProviderInstance
,
method
string
,
providerKey
string
)
[]
*
dbent
.
PaymentProviderInstance
{
metadata
:=
map
[
string
]
string
{
filtered
:=
make
([]
*
dbent
.
PaymentProviderInstance
,
0
,
len
(
instances
))
"payment_method"
:
NormalizeVisibleMethod
(
method
),
for
_
,
inst
:=
range
instances
{
if
!
providerSupportsVisibleMethod
(
inst
,
method
)
{
continue
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
inst
.
ProviderKey
),
strings
.
TrimSpace
(
providerKey
))
{
continue
}
filtered
=
append
(
filtered
,
inst
)
}
return
filtered
}
func
distinctVisibleMethodProviderKeys
(
instances
[]
*
dbent
.
PaymentProviderInstance
)
[]
string
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
instances
))
keys
:=
make
([]
string
,
0
,
len
(
instances
))
for
_
,
inst
:=
range
instances
{
if
inst
==
nil
{
continue
}
key
:=
strings
.
TrimSpace
(
inst
.
ProviderKey
)
if
key
==
""
{
continue
}
normalized
:=
strings
.
ToLower
(
key
)
if
_
,
ok
:=
seen
[
normalized
];
ok
{
continue
}
seen
[
normalized
]
=
struct
{}{}
keys
=
append
(
keys
,
key
)
}
}
if
conflicting
!=
nil
{
return
keys
metadata
[
"conflicting_provider_id"
]
=
fmt
.
Sprintf
(
"%d"
,
conflicting
.
ID
)
}
metadata
[
"conflicting_provider_key"
]
=
conflicting
.
ProviderKey
metadata
[
"conflicting_provider_name"
]
=
conflicting
.
Name
func
selectVisibleMethodInstanceByProviderKey
(
instances
[]
*
dbent
.
PaymentProviderInstance
,
providerKey
string
)
*
dbent
.
PaymentProviderInstance
{
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerKey
==
""
{
return
nil
}
for
_
,
inst
:=
range
instances
{
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
inst
.
ProviderKey
),
providerKey
)
{
return
inst
}
}
}
return
infraerrors
.
Conflict
(
return
nil
"PAYMENT_PROVIDER_CONFLICT"
,
fmt
.
Sprintf
(
"%s payment already has an enabled provider instance"
,
NormalizeVisibleMethod
(
method
)),
)
.
WithMetadata
(
metadata
)
}
}
func
(
s
*
PaymentConfigService
)
validateVisibleMethodEnablementConflicts
(
func
(
s
*
PaymentConfigService
)
validateVisibleMethodEnablementConflicts
(
...
@@ -104,33 +138,72 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
...
@@ -104,33 +138,72 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes
string
,
supportedTypes
string
,
enabled
bool
,
enabled
bool
,
)
error
{
)
error
{
if
s
==
nil
||
s
.
entClient
==
nil
||
!
enabled
{
// Visible methods are selected by configured source (official/easypay),
return
nil
// so multiple enabled providers can intentionally claim the same user-facing
}
// method. Order creation and limits will route through the configured source.
_
,
_
,
_
,
_
,
_
=
ctx
,
excludeID
,
providerKey
,
supportedTypes
,
enabled
return
nil
}
claimedMethods
:=
enabledVisibleMethodsForProvider
(
providerKey
,
supportedTypes
)
func
(
s
*
PaymentConfigService
)
resolveVisibleMethodSourceProviderKey
(
ctx
context
.
Context
,
method
string
)
(
string
,
error
)
{
if
len
(
claimedMethods
)
==
0
{
method
=
NormalizeVisibleMethod
(
method
)
return
nil
sourceKey
:=
visibleMethodSourceSettingKey
(
method
)
rawSource
:=
""
if
s
!=
nil
&&
s
.
settingRepo
!=
nil
&&
sourceKey
!=
""
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
sourceKey
)
if
err
!=
nil
{
if
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
fmt
.
Errorf
(
"get %s: %w"
,
sourceKey
,
err
)
}
}
else
{
rawSource
=
value
}
}
}
query
:=
s
.
entClient
.
PaymentProviderInstance
.
Query
()
.
normalizedSource
,
err
:=
normalizeVisibleMethodSettingSource
(
method
,
rawSource
,
true
)
Where
(
paymentproviderinstance
.
EnabledEQ
(
true
))
if
excludeID
>
0
{
query
=
query
.
Where
(
paymentproviderinstance
.
IDNEQ
(
excludeID
))
}
instances
,
err
:=
query
.
All
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"query enabled payment providers: %w"
,
err
)
return
""
,
err
}
if
normalizedSource
==
""
{
return
""
,
nil
}
providerKey
,
ok
:=
VisibleMethodProviderKeyForSource
(
method
,
normalizedSource
)
if
!
ok
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE"
,
fmt
.
Sprintf
(
"%s source must be one of the supported payment providers"
,
method
),
)
}
}
return
providerKey
,
nil
}
for
_
,
method
:=
range
claimedMethods
{
func
(
s
*
PaymentConfigService
)
resolveVisibleMethodProviderKey
(
for
_
,
inst
:=
range
instances
{
ctx
context
.
Context
,
if
providerSupportsVisibleMethod
(
inst
,
method
)
{
method
string
,
return
buildPaymentProviderConflictError
(
method
,
inst
)
matching
[]
*
dbent
.
PaymentProviderInstance
,
}
)
(
string
,
error
)
{
switch
providerKeys
:=
distinctVisibleMethodProviderKeys
(
matching
);
len
(
providerKeys
)
{
case
0
:
return
""
,
nil
case
1
:
return
strings
.
TrimSpace
(
providerKeys
[
0
]),
nil
default
:
providerKey
,
err
:=
s
.
resolveVisibleMethodSourceProviderKey
(
ctx
,
method
)
if
err
!=
nil
{
return
""
,
err
}
if
providerKey
==
""
{
return
""
,
nil
}
selected
:=
selectVisibleMethodInstanceByProviderKey
(
matching
,
providerKey
)
if
selected
==
nil
{
return
""
,
infraerrors
.
BadRequest
(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE"
,
fmt
.
Sprintf
(
"%s source has no enabled provider instance"
,
method
),
)
}
}
return
strings
.
TrimSpace
(
selected
.
ProviderKey
),
nil
}
}
return
nil
}
}
func
(
s
*
PaymentConfigService
)
resolveEnabledVisibleMethodInstance
(
func
(
s
*
PaymentConfigService
)
resolveEnabledVisibleMethodInstance
(
...
@@ -155,12 +228,15 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
...
@@ -155,12 +228,15 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
}
}
matching
:=
filterEnabledVisibleMethodInstances
(
instances
,
method
)
matching
:=
filterEnabledVisibleMethodInstances
(
instances
,
method
)
switch
len
(
matching
)
{
providerKey
,
err
:=
s
.
resolveVisibleMethodProviderKey
(
ctx
,
method
,
matching
)
case
0
:
if
err
!=
nil
{
return
nil
,
nil
return
nil
,
err
case
1
:
}
return
matching
[
0
],
nil
if
providerKey
==
""
{
default
:
if
len
(
matching
)
==
0
{
return
nil
,
buildPaymentProviderConflictError
(
method
,
matching
[
0
])
return
nil
,
nil
}
return
&
dbent
.
PaymentProviderInstance
{
ProviderKey
:
""
},
nil
}
}
return
selectVisibleMethodInstanceByProviderKey
(
matching
,
providerKey
),
nil
}
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment