Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ddf80f5e
Unverified
Commit
ddf80f5e
authored
Apr 22, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 22, 2026
Browse files
Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation
fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents
4d0483f5
c048ca80
Changes
140
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/auth_oauth_logout_test.go
0 → 100644
View file @
ddf80f5e
package
handler
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"logout-pending-session-token"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example"
)
.
SetProviderSubject
(
"logout-subject-123"
)
.
SetBrowserSessionKey
(
"logout-browser-session-key"
)
.
SetResolvedEmail
(
"logout@example.com"
)
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/logout"
,
nil
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"logout-browser-session-key"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthBindAccessTokenCookieName
,
Value
:
"bind-access-token"
})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
linuxDoOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"linuxdo-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oidcOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"oidc-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
wechatOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"wechat-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
wechatPaymentOAuthStateName
,
Value
:
encodeCookieValue
(
"wechat-payment-state"
)})
ginCtx
.
Request
=
req
handler
.
Logout
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
cookies
:=
recorder
.
Result
()
.
Cookies
()
for
_
,
name
:=
range
[]
string
{
oauthPendingSessionCookieName
,
oauthPendingBrowserCookieName
,
oauthBindAccessTokenCookieName
,
linuxDoOAuthStateCookieName
,
oidcOAuthStateCookieName
,
wechatOAuthStateCookieName
,
wechatPaymentOAuthStateName
,
}
{
cookie
:=
findCookie
(
cookies
,
name
)
require
.
NotNil
(
t
,
cookie
,
name
)
require
.
Equal
(
t
,
-
1
,
cookie
.
MaxAge
,
name
)
require
.
True
(
t
,
cookie
.
HttpOnly
,
name
)
}
storedSession
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
IDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
storedSession
.
ConsumedAt
)
}
backend/internal/handler/auth_oauth_pending_flow.go
View file @
ddf80f5e
...
@@ -265,16 +265,20 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
...
@@ -265,16 +265,20 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return
strings
.
EqualFold
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
payload
,
"error"
)),
"invitation_required"
)
return
strings
.
EqualFold
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
payload
,
"error"
)),
"invitation_required"
)
}
}
func
pendingOAuthCompletion
IncludesTokenPayload
(
payload
map
[
string
]
any
)
bool
{
func
pendingOAuthCompletion
CanIssueTokenPair
(
session
*
dbent
.
PendingAuthSession
,
payload
map
[
string
]
any
)
bool
{
if
len
(
payload
)
==
0
{
if
session
==
nil
{
return
false
return
false
}
}
for
_
,
key
:=
range
[]
string
{
"access_token"
,
"refresh_token"
}
{
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
Intent
),
oauthIntentLogin
)
{
if
value
:=
pendingSessionStringValue
(
payload
,
key
);
value
!=
""
{
return
false
return
true
}
}
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
return
false
}
if
pendingSessionWantsInvitation
(
payload
)
{
return
false
}
}
return
false
return
strings
.
TrimSpace
(
pendingSessionStringValue
(
payload
,
"step"
))
==
""
}
}
func
ensurePendingOAuthCompleteRegistrationSession
(
session
*
dbent
.
PendingAuthSession
)
error
{
func
ensurePendingOAuthCompleteRegistrationSession
(
session
*
dbent
.
PendingAuthSession
)
error
{
...
@@ -294,6 +298,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
...
@@ -294,6 +298,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return
nil
return
nil
}
}
func
buildLegacyCompleteRegistrationPendingResponse
(
session
*
dbent
.
PendingAuthSession
,
forceEmailOnSignup
bool
,
emailVerificationRequired
bool
,
)
map
[
string
]
any
{
completionResponse
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"adoption_required"
:
true
,
"create_account_allowed"
:
true
,
"force_email_on_signup"
:
forceEmailOnSignup
,
}))
if
email
:=
strings
.
TrimSpace
(
session
.
ResolvedEmail
);
email
!=
""
{
if
_
,
exists
:=
completionResponse
[
"email"
];
!
exists
{
completionResponse
[
"email"
]
=
email
}
if
_
,
exists
:=
completionResponse
[
"resolved_email"
];
!
exists
{
completionResponse
[
"resolved_email"
]
=
email
}
}
if
_
,
exists
:=
completionResponse
[
"choice_reason"
];
!
exists
{
switch
{
case
forceEmailOnSignup
:
completionResponse
[
"choice_reason"
]
=
"force_email_on_signup"
case
emailVerificationRequired
:
completionResponse
[
"choice_reason"
]
=
"email_verification_required"
default
:
completionResponse
[
"choice_reason"
]
=
"third_party_signup"
}
}
return
completionResponse
}
func
(
h
*
AuthHandler
)
legacyCompleteRegistrationSessionStatus
(
c
*
gin
.
Context
,
session
*
dbent
.
PendingAuthSession
,
)
(
*
dbent
.
PendingAuthSession
,
bool
,
error
)
{
if
session
==
nil
{
return
nil
,
false
,
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
payload
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
nil
))
if
step
:=
pendingSessionStringValue
(
payload
,
"step"
);
step
!=
""
{
return
session
,
true
,
nil
}
emailVerificationRequired
:=
h
!=
nil
&&
h
.
authService
!=
nil
&&
h
.
authService
.
IsEmailVerifyEnabled
(
c
.
Request
.
Context
())
forceEmailOnSignup
:=
h
.
isForceEmailOnThirdPartySignup
(
c
.
Request
.
Context
())
if
!
emailVerificationRequired
&&
!
forceEmailOnSignup
{
return
session
,
false
,
nil
}
client
:=
h
.
entClient
()
if
client
==
nil
{
return
nil
,
false
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
updatedSession
,
err
:=
updatePendingOAuthSessionProgress
(
c
.
Request
.
Context
(),
client
,
session
,
strings
.
TrimSpace
(
session
.
Intent
),
strings
.
TrimSpace
(
session
.
ResolvedEmail
),
nil
,
buildLegacyCompleteRegistrationPendingResponse
(
session
,
forceEmailOnSignup
,
emailVerificationRequired
),
)
if
err
!=
nil
{
return
nil
,
false
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_SESSION_UPDATE_FAILED"
,
"failed to update pending oauth session"
)
.
WithCause
(
err
)
}
return
updatedSession
,
true
,
nil
}
func
(
r
oauthAdoptionDecisionRequest
)
hasDecision
()
bool
{
func
(
r
oauthAdoptionDecisionRequest
)
hasDecision
()
bool
{
return
r
.
AdoptDisplayName
!=
nil
||
r
.
AdoptAvatar
!=
nil
return
r
.
AdoptDisplayName
!=
nil
||
r
.
AdoptAvatar
!=
nil
}
}
...
@@ -376,15 +452,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
...
@@ -376,15 +452,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
}
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
}
return
findActiveUserByID
(
ctx
,
client
,
record
.
UserID
)
userEntity
,
err
:=
client
.
User
.
Get
(
ctx
,
record
.
UserID
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_USER_LOOKUP_FAILED"
,
"failed to load auth identity user"
)
.
WithCause
(
err
)
}
return
userEntity
,
nil
}
}
func
(
h
*
AuthHandler
)
BindLinuxDoOAuthLogin
(
c
*
gin
.
Context
)
{
h
.
bindPendingOAuthLogin
(
c
,
"linuxdo"
)
}
func
(
h
*
AuthHandler
)
BindLinuxDoOAuthLogin
(
c
*
gin
.
Context
)
{
h
.
bindPendingOAuthLogin
(
c
,
"linuxdo"
)
}
...
@@ -439,7 +507,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
...
@@ -439,7 +507,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
req
.
Email
))
email
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
req
.
Email
))
if
existingUser
,
err
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
);
err
==
nil
&&
existingUser
!=
nil
{
if
existingUser
,
err
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
);
err
==
nil
&&
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -624,6 +692,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
...
@@ -624,6 +692,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return
matches
[
0
],
nil
return
matches
[
0
],
nil
}
}
func
ensurePendingOAuthRegistrationIdentityAvailable
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
)
error
{
if
client
==
nil
||
session
==
nil
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
session
.
ProviderType
)),
authidentity
.
ProviderKeyEQ
(
strings
.
TrimSpace
(
session
.
ProviderKey
)),
authidentity
.
ProviderSubjectEQ
(
strings
.
TrimSpace
(
session
.
ProviderSubject
)),
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
}
return
err
}
if
identity
==
nil
||
identity
.
UserID
<=
0
{
return
nil
}
activeOwner
,
err
:=
findActiveUserByID
(
ctx
,
client
,
identity
.
UserID
)
if
err
!=
nil
{
return
err
}
if
activeOwner
!=
nil
{
return
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
return
nil
}
func
oauthIdentityIssuer
(
session
*
dbent
.
PendingAuthSession
)
*
string
{
func
oauthIdentityIssuer
(
session
*
dbent
.
PendingAuthSession
)
*
string
{
if
session
==
nil
{
if
session
==
nil
{
return
nil
return
nil
...
@@ -910,6 +1010,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
...
@@ -910,6 +1010,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
}
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_USER_LOOKUP_FAILED"
,
"failed to load auth identity user"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_USER_LOOKUP_FAILED"
,
"failed to load auth identity user"
)
.
WithCause
(
err
)
}
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
userEntity
.
Status
),
service
.
StatusActive
)
{
return
nil
,
service
.
ErrUserNotActive
}
return
userEntity
,
nil
return
userEntity
,
nil
}
}
...
@@ -1123,6 +1226,38 @@ func consumePendingOAuthBrowserSessionTx(
...
@@ -1123,6 +1226,38 @@ func consumePendingOAuthBrowserSessionTx(
return
nil
return
nil
}
}
func
applyPendingOAuthAdoptionAndConsumeSession
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
session
*
dbent
.
PendingAuthSession
,
decision
*
dbent
.
IdentityAdoptionDecision
,
userID
int64
,
)
error
{
if
client
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
if
session
==
nil
||
userID
<=
0
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
tx
,
err
:=
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txCtx
:=
dbent
.
NewTxContext
(
ctx
,
tx
)
if
err
:=
applyPendingOAuthAdoption
(
txCtx
,
client
,
authService
,
userService
,
session
,
decision
,
&
userID
);
err
!=
nil
{
return
err
}
if
err
:=
consumePendingOAuthBrowserSessionTx
(
txCtx
,
tx
,
session
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
applyPendingOAuthAdoption
(
func
applyPendingOAuthAdoption
(
ctx
context
.
Context
,
ctx
context
.
Context
,
client
*
dbent
.
Client
,
client
*
dbent
.
Client
,
...
@@ -1212,13 +1347,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
...
@@ -1212,13 +1347,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
if
session
==
nil
||
len
(
payload
)
==
0
{
if
session
==
nil
||
len
(
payload
)
==
0
{
return
false
,
nil
return
false
,
nil
}
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
Intent
),
oauthIntentLogin
)
{
if
!
pendingOAuthCompletionCanIssueTokenPair
(
session
,
payload
)
{
return
false
,
nil
}
if
!
pendingOAuthCompletionIncludesTokenPayload
(
payload
)
{
return
false
,
nil
}
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
return
false
,
nil
return
false
,
nil
}
}
if
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"suggested_display_name"
)
==
""
&&
if
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"suggested_display_name"
)
==
""
&&
...
@@ -1262,6 +1391,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
...
@@ -1262,6 +1391,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return
svc
,
session
,
clearCookies
,
nil
return
svc
,
session
,
clearCookies
,
nil
}
}
func
(
h
*
AuthHandler
)
consumePendingOAuthSessionOnLogout
(
c
*
gin
.
Context
)
{
if
c
==
nil
||
c
.
Request
==
nil
{
return
}
sessionToken
,
err
:=
readOAuthPendingSessionCookie
(
c
)
if
err
!=
nil
||
strings
.
TrimSpace
(
sessionToken
)
==
""
{
return
}
browserSessionKey
,
err
:=
readOAuthPendingBrowserCookie
(
c
)
if
err
!=
nil
||
strings
.
TrimSpace
(
browserSessionKey
)
==
""
{
return
}
svc
,
err
:=
h
.
pendingIdentityService
()
if
err
!=
nil
{
return
}
_
,
_
=
svc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
)
}
func
clearOAuthLogoutCookies
(
c
*
gin
.
Context
)
{
secureCookie
:=
isRequestHTTPS
(
c
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
clearOAuthBindAccessTokenCookie
(
c
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthStateCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthVerifierCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthRedirectCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthIntentCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthBindUserCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthStateCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthVerifierCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthRedirectCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthNonceCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthIntentCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthBindUserCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthStateCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthRedirectCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthIntentCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthModeCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthBindUserCookieName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthStateName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthRedirect
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthContextName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthScope
,
secureCookie
)
}
func
buildPendingOAuthSessionStatusPayload
(
session
*
dbent
.
PendingAuthSession
)
gin
.
H
{
func
buildPendingOAuthSessionStatusPayload
(
session
*
dbent
.
PendingAuthSession
)
gin
.
H
{
completionResponse
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
nil
))
completionResponse
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
nil
))
payload
:=
gin
.
H
{
payload
:=
gin
.
H
{
...
@@ -1280,6 +1462,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
...
@@ -1280,6 +1462,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func
normalizePendingOAuthCompletionResponse
(
payload
map
[
string
]
any
)
map
[
string
]
any
{
func
normalizePendingOAuthCompletionResponse
(
payload
map
[
string
]
any
)
map
[
string
]
any
{
normalized
:=
clonePendingMap
(
payload
)
normalized
:=
clonePendingMap
(
payload
)
for
_
,
key
:=
range
[]
string
{
"access_token"
,
"refresh_token"
,
"expires_in"
,
"token_type"
}
{
delete
(
normalized
,
key
)
}
step
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
normalized
,
"step"
)))
step
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
normalized
,
"step"
)))
switch
step
{
switch
step
{
case
"choice"
,
"choose_account_action"
,
"choose_account"
,
"choose"
,
"email_required"
,
"bind_login_required"
:
case
"choice"
,
"choose_account_action"
,
"choose_account"
,
"choose"
,
"email_required"
,
"bind_login_required"
:
...
@@ -1315,16 +1500,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
...
@@ -1315,16 +1500,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c
*
gin
.
Context
,
c
*
gin
.
Context
,
client
*
dbent
.
Client
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
,
session
*
dbent
.
PendingAuthSession
,
targetUser
*
dbent
.
User
,
email
string
,
email
string
,
)
(
*
dbent
.
PendingAuthSession
,
error
)
{
)
(
*
dbent
.
PendingAuthSession
,
error
)
{
completionResponse
:=
pendingOAuthChoiceCompletionResponse
(
session
,
email
)
completionResponse
:=
pendingOAuthChoiceCompletionResponse
(
session
,
email
)
var
targetUserID
*
int64
if
targetUser
!=
nil
&&
targetUser
.
ID
>
0
{
targetUserID
=
&
targetUser
.
ID
}
session
,
err
:=
updatePendingOAuthSessionProgress
(
session
,
err
:=
updatePendingOAuthSessionProgress
(
c
.
Request
.
Context
(),
c
.
Request
.
Context
(),
client
,
client
,
session
,
session
,
strings
.
TrimSpace
(
session
.
Intent
),
strings
.
TrimSpace
(
session
.
Intent
),
email
,
email
,
nil
,
targetUserID
,
completionResponse
,
completionResponse
,
)
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -1438,6 +1628,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
...
@@ -1438,6 +1628,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
err
:=
ensurePendingOAuthCompleteRegistrationSession
(
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
strings
.
TrimSpace
(
provider
)
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
ProviderType
),
provider
)
{
if
strings
.
TrimSpace
(
provider
)
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
ProviderType
),
provider
)
{
response
.
BadRequest
(
c
,
"Pending oauth session provider mismatch"
)
response
.
BadRequest
(
c
,
"Pending oauth session provider mismatch"
)
return
return
...
@@ -1464,7 +1658,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
...
@@ -1464,7 +1658,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
}
}
}
if
existingUser
!=
nil
{
if
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -1487,7 +1681,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
...
@@ -1487,7 +1681,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
)
)
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrEmailExists
)
{
if
errors
.
Is
(
err
,
service
.
ErrEmailExists
)
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
existingUser
,
lookupErr
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
)
if
lookupErr
!=
nil
{
response
.
ErrorFrom
(
c
,
lookupErr
)
return
}
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -1649,33 +1848,35 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
...
@@ -1649,33 +1848,35 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
}
}
applySuggestedProfileToCompletionResponse
(
payload
,
session
.
UpstreamIdentityClaims
)
applySuggestedProfileToCompletionResponse
(
payload
,
session
.
UpstreamIdentityClaims
)
skipAdoptionPrompt
,
err
:=
h
.
shouldSkipPendingOAuthAdoptionPrompt
(
c
.
Request
.
Context
(),
session
,
payload
)
if
err
!=
nil
{
canIssueTokenPair
:=
pendingOAuthCompletionCanIssueTokenPair
(
session
,
payload
)
clearCookies
()
var
loginUser
*
service
.
User
response
.
ErrorFrom
(
c
,
err
)
if
canIssueTokenPair
{
return
loginUser
,
err
=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
*
session
.
TargetUserID
)
}
if
err
!=
nil
{
if
skipAdoptionPrompt
{
delete
(
payload
,
"adoption_required"
)
}
if
pendingOAuthCompletionIncludesTokenPayload
(
payload
)
{
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
clearCookies
()
clearCookies
()
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_COMPLETION_INVALID"
,
"pending auth completion payload is invalid"
)
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
user
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
*
session
.
TargetUserID
)
if
err
:=
ensureLoginUserActive
(
loginUser
);
err
!=
nil
{
if
err
!=
nil
{
clearCookies
()
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
err
:=
h
.
ensureBackendModeAllowsUser
(
c
.
Request
.
Context
(),
u
ser
);
err
!=
nil
{
if
err
:=
h
.
ensureBackendModeAllowsUser
(
c
.
Request
.
Context
(),
loginU
ser
);
err
!=
nil
{
clearCookies
()
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
}
skipAdoptionPrompt
,
err
:=
h
.
shouldSkipPendingOAuthAdoptionPrompt
(
c
.
Request
.
Context
(),
session
,
payload
)
if
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
return
}
if
skipAdoptionPrompt
{
delete
(
payload
,
"adoption_required"
)
}
}
if
pendingSessionWantsInvitation
(
payload
)
{
if
pendingSessionWantsInvitation
(
payload
)
{
...
@@ -1724,6 +1925,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
...
@@ -1724,6 +1925,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
return
}
}
if
canIssueTokenPair
{
tokenPair
,
err
:=
h
.
authService
.
GenerateTokenPair
(
c
.
Request
.
Context
(),
loginUser
,
""
)
if
err
!=
nil
{
clearCookies
()
response
.
InternalError
(
c
,
"Failed to generate token pair"
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
loginUser
.
ID
)
payload
[
"access_token"
]
=
tokenPair
.
AccessToken
payload
[
"refresh_token"
]
=
tokenPair
.
RefreshToken
payload
[
"expires_in"
]
=
tokenPair
.
ExpiresIn
payload
[
"token_type"
]
=
"Bearer"
}
clearCookies
()
clearCookies
()
response
.
Success
(
c
,
payload
)
response
.
Success
(
c
,
payload
)
}
}
backend/internal/handler/auth_oauth_pending_flow_test.go
View file @
ddf80f5e
...
@@ -746,8 +746,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
...
@@ -746,8 +746,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
})
.
})
.
SetLocalFlowState
(
map
[
string
]
any
{
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"access_token"
:
"access-token"
,
"access_token"
:
"
legacy-
access-token"
,
"refresh_token"
:
"refresh-token"
,
"refresh_token"
:
"
legacy-
refresh-token"
,
"expires_in"
:
float64
(
3600
),
"expires_in"
:
float64
(
3600
),
"token_type"
:
"Bearer"
,
"token_type"
:
"Bearer"
,
"redirect"
:
"/dashboard"
,
"redirect"
:
"/dashboard"
,
...
@@ -769,13 +769,23 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
...
@@ -769,13 +769,23 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
require
.
Equal
(
t
,
"access-token"
,
payload
[
"access_token"
])
require
.
NotEmpty
(
t
,
payload
[
"access_token"
])
require
.
Equal
(
t
,
"refresh-token"
,
payload
[
"refresh_token"
])
require
.
NotEmpty
(
t
,
payload
[
"refresh_token"
])
require
.
NotEqual
(
t
,
"legacy-access-token"
,
payload
[
"access_token"
])
require
.
NotEqual
(
t
,
"legacy-refresh-token"
,
payload
[
"refresh_token"
])
require
.
Equal
(
t
,
"/dashboard"
,
payload
[
"redirect"
])
require
.
Equal
(
t
,
"/dashboard"
,
payload
[
"redirect"
])
require
.
Equal
(
t
,
"Existing Login Example"
,
payload
[
"suggested_display_name"
])
require
.
Equal
(
t
,
"Existing Login Example"
,
payload
[
"suggested_display_name"
])
require
.
Equal
(
t
,
"https://cdn.example/existing-login.png"
,
payload
[
"suggested_avatar_url"
])
require
.
Equal
(
t
,
"https://cdn.example/existing-login.png"
,
payload
[
"suggested_avatar_url"
])
require
.
NotContains
(
t
,
payload
,
"adoption_required"
)
require
.
NotContains
(
t
,
payload
,
"adoption_required"
)
accessToken
,
ok
:=
payload
[
"access_token"
]
.
(
string
)
require
.
True
(
t
,
ok
)
claims
,
err
:=
handler
.
authService
.
ValidateToken
(
accessToken
)
require
.
NoError
(
t
,
err
)
reloadedUser
,
err
:=
handler
.
userService
.
GetByID
(
ctx
,
userEntity
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
reloadedUser
.
TokenVersion
,
claims
.
TokenVersion
)
decisionCount
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
decisionCount
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Count
(
ctx
)
Count
(
ctx
)
...
@@ -785,6 +795,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
...
@@ -785,6 +795,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
storedSession
.
ConsumedAt
)
require
.
NotNil
(
t
,
storedSession
.
ConsumedAt
)
completion
,
ok
:=
storedSession
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
NotContains
(
t
,
completion
,
"access_token"
)
require
.
NotContains
(
t
,
completion
,
"refresh_token"
)
require
.
NotContains
(
t
,
completion
,
"expires_in"
)
require
.
NotContains
(
t
,
completion
,
"token_type"
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
}
}
func
TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload
(
t
*
testing
.
T
)
{
func
TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload
(
t
*
testing
.
T
)
{
...
@@ -841,6 +859,72 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
...
@@ -841,6 +859,72 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
}
func
TestExchangePendingOAuthCompletionRejectsDisabledTargetUser
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
userEntity
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"disabled-linked@example.com"
)
.
SetUsername
(
"disabled-linked-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"disabled-linked-session-token"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"linuxdo"
)
.
SetProviderKey
(
"linuxdo"
)
.
SetProviderSubject
(
"disabled-linked-subject"
)
.
SetTargetUserID
(
userEntity
.
ID
)
.
SetResolvedEmail
(
userEntity
.
Email
)
.
SetBrowserSessionKey
(
"disabled-linked-browser-session-key"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"suggested_display_name"
:
"Disabled Linked User"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"redirect"
:
"/dashboard"
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/pending/exchange"
,
nil
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"disabled-linked-browser-session-key"
)})
ginCtx
.
Request
=
req
handler
.
ExchangePendingOAuthCompletion
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
recorder
.
Code
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload
(
t
*
testing
.
T
)
{
payload
:=
normalizePendingOAuthCompletionResponse
(
map
[
string
]
any
{
"access_token"
:
"legacy-access-token"
,
"refresh_token"
:
"legacy-refresh-token"
,
"expires_in"
:
float64
(
3600
),
"token_type"
:
"Bearer"
,
"redirect"
:
"/dashboard"
,
})
require
.
NotContains
(
t
,
payload
,
"access_token"
)
require
.
NotContains
(
t
,
payload
,
"refresh_token"
)
require
.
NotContains
(
t
,
payload
,
"expires_in"
)
require
.
NotContains
(
t
,
payload
,
"token_type"
)
require
.
Equal
(
t
,
"/dashboard"
,
payload
[
"redirect"
])
}
func
TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding
(
t
*
testing
.
T
)
{
func
TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
true
)
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
true
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -969,7 +1053,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
...
@@ -969,7 +1053,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetPasswordHash
(
"hash"
)
.
...
@@ -1023,7 +1107,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
...
@@ -1023,7 +1107,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
...
@@ -1042,7 +1127,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
...
@@ -1042,7 +1127,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
" Owner@Example.com "
)
.
SetEmail
(
" Owner@Example.com "
)
.
SetUsername
(
"owner-user"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetPasswordHash
(
"hash"
)
.
...
@@ -1088,7 +1173,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
...
@@ -1088,7 +1173,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
}
...
@@ -1096,7 +1182,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
...
@@ -1096,7 +1182,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetPasswordHash
(
"hash"
)
.
...
@@ -1144,7 +1230,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
...
@@ -1144,7 +1230,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
}
...
@@ -1202,6 +1289,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
...
@@ -1202,6 +1289,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
}
func
TestLogoutClearsPendingOAuthAndBindCookies
(
t
*
testing
.
T
)
{
handler
,
_
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/logout"
,
bytes
.
NewBufferString
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
"pending-session-token"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"pending-browser-key"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthBindAccessTokenCookieName
,
Value
:
"bind-token"
})
ginCtx
.
Request
=
req
handler
.
Logout
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
)
.
MaxAge
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingBrowserCookieName
)
.
MaxAge
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthBindAccessTokenCookieName
)
.
MaxAge
)
}
func
TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails
(
t
*
testing
.
T
)
{
func
TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
true
,
"fresh@example.com"
,
"246810"
)
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
true
,
"fresh@example.com"
,
"246810"
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -1934,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
...
@@ -1934,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
payload
[
"access_token"
])
require
.
NotEmpty
(
t
,
payload
[
"access_token"
])
require
.
NotEmpty
(
t
,
payload
[
"refresh_token"
])
require
.
NotEmpty
(
t
,
payload
[
"refresh_token"
])
accessToken
,
ok
:=
payload
[
"access_token"
]
.
(
string
)
require
.
True
(
t
,
ok
)
claims
,
err
:=
handler
.
authService
.
ValidateToken
(
accessToken
)
require
.
NoError
(
t
,
err
)
reloadedUser
,
err
:=
handler
.
userService
.
GetByID
(
ctx
,
existingUser
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
reloadedUser
.
TokenVersion
,
claims
.
TokenVersion
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
Where
(
...
...
backend/internal/handler/auth_oauth_test_helpers_test.go
View file @
ddf80f5e
...
@@ -2,6 +2,7 @@ package handler
...
@@ -2,6 +2,7 @@ package handler
import
(
import
(
"net/http"
"net/http"
"net/url"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
...
@@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
return
decoded
return
decoded
}
}
func
assertOAuthRedirectError
(
t
*
testing
.
T
,
location
string
,
errorCode
string
,
errorMessage
string
)
{
t
.
Helper
()
require
.
NotEmpty
(
t
,
location
)
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
rawValues
:=
parsed
.
RawQuery
if
rawValues
==
""
{
rawValues
=
parsed
.
Fragment
}
values
,
err
:=
url
.
ParseQuery
(
rawValues
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
errorCode
,
values
.
Get
(
"error"
))
require
.
Equal
(
t
,
errorMessage
,
values
.
Get
(
"error_message"
))
}
backend/internal/handler/auth_oidc_oauth.go
View file @
ddf80f5e
...
@@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
...
@@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
}
codeChallenge
:=
""
codeChallenge
:=
""
verifier
,
genErr
:=
oauth
.
GenerateCodeVerifier
()
if
cfg
.
UsePKCE
{
if
genErr
!=
nil
{
verifier
,
genErr
:=
oauth
.
GenerateCodeVerifier
()
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
genErr
))
if
genErr
!=
nil
{
return
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
genErr
))
return
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
oidcSetCookie
(
c
,
oidcOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
oidcSetCookie
(
c
,
oidcOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
nonce
:=
""
nonce
:=
""
nonce
,
err
=
oauth
.
GenerateState
()
if
cfg
.
ValidateIDToken
{
if
err
!=
nil
{
nonce
,
err
=
oauth
.
GenerateState
()
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_NONCE_GEN_FAILED"
,
"failed to generate oauth nonce"
)
.
WithCause
(
err
))
if
err
!=
nil
{
return
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_NONCE_GEN_FAILED"
,
"failed to generate oauth nonce"
)
.
WithCause
(
err
))
return
}
oidcSetCookie
(
c
,
oidcOAuthNonceCookie
,
encodeCookieValue
(
nonce
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
}
oidcSetCookie
(
c
,
oidcOAuthNonceCookie
,
encodeCookieValue
(
nonce
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
if
redirectURI
==
""
{
...
@@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
intent
=
normalizeOAuthIntent
(
intent
)
intent
=
normalizeOAuthIntent
(
intent
)
codeVerifier
:=
""
codeVerifier
:=
""
codeVerifier
,
_
=
readCookieDecoded
(
c
,
oidcOAuthVerifierCookie
)
if
cfg
.
UsePKCE
{
if
codeVerifier
==
""
{
codeVerifier
,
_
=
readCookieDecoded
(
c
,
oidcOAuthVerifierCookie
)
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
if
codeVerifier
==
""
{
return
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
}
}
}
expectedNonce
:=
""
expectedNonce
:=
""
expectedNonce
,
_
=
readCookieDecoded
(
c
,
oidcOAuthNonceCookie
)
if
cfg
.
ValidateIDToken
{
if
expectedNonce
==
""
{
expectedNonce
,
_
=
readCookieDecoded
(
c
,
oidcOAuthNonceCookie
)
redirectOAuthError
(
c
,
frontendCallback
,
"missing_nonce"
,
"missing oauth nonce"
,
""
)
if
expectedNonce
==
""
{
return
redirectOAuthError
(
c
,
frontendCallback
,
"missing_nonce"
,
"missing oauth nonce"
,
""
)
return
}
}
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
...
@@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
return
}
}
if
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
var
idClaims
*
oidcIDTokenClaims
redirectOAuthError
(
c
,
frontendCallback
,
"missing_id_token"
,
"missing id_token"
,
""
)
if
cfg
.
ValidateIDToken
{
return
if
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
}
redirectOAuthError
(
c
,
frontendCallback
,
"missing_id_token"
,
"missing id_token"
,
""
)
return
}
idClaims
,
err
:=
oidcParseAndValidateIDToken
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
.
IDToken
,
expectedNonce
)
idClaims
,
err
=
oidcParseAndValidateIDToken
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
.
IDToken
,
expectedNonce
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] id_token validation failed: %v"
,
err
)
log
.
Printf
(
"[OIDC OAuth] id_token validation failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_id_token"
,
"failed to validate id_token"
,
""
)
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_id_token"
,
"failed to validate id_token"
,
""
)
return
return
}
}
}
userInfoClaims
,
err
:=
oidcFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
userInfoClaims
,
err
:=
oidcFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
...
@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
return
}
}
subject
:=
strings
.
TrimSpace
(
idClaims
.
Subject
)
subject
:=
""
if
idClaims
!=
nil
{
subject
=
strings
.
TrimSpace
(
idClaims
.
Subject
)
}
if
subject
==
""
{
if
subject
==
""
{
subject
=
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
subject
=
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
}
}
...
@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError
(
c
,
frontendCallback
,
"missing_subject"
,
"missing subject claim"
,
""
)
redirectOAuthError
(
c
,
frontendCallback
,
"missing_subject"
,
"missing subject claim"
,
""
)
return
return
}
}
issuer
:=
strings
.
TrimSpace
(
idClaims
.
Issuer
)
issuer
:=
""
if
idClaims
!=
nil
{
issuer
=
strings
.
TrimSpace
(
idClaims
.
Issuer
)
}
if
issuer
==
""
{
if
issuer
==
""
{
issuer
=
strings
.
TrimSpace
(
cfg
.
IssuerURL
)
issuer
=
strings
.
TrimSpace
(
cfg
.
IssuerURL
)
}
}
...
@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
}
emailVerified
:=
userInfoClaims
.
EmailVerified
emailVerified
:=
userInfoClaims
.
EmailVerified
if
emailVerified
==
nil
{
if
emailVerified
==
nil
&&
idClaims
!=
nil
{
emailVerified
=
idClaims
.
EmailVerified
emailVerified
=
idClaims
.
EmailVerified
}
}
if
userInfoClaims
.
Subject
!=
""
&&
idClaims
.
Subject
!=
""
&&
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
!=
strings
.
TrimSpace
(
idClaims
.
Subject
)
{
if
idClaims
!=
nil
&&
userInfoClaims
.
Subject
!=
""
&&
idClaims
.
Subject
!=
""
&&
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
!=
strings
.
TrimSpace
(
idClaims
.
Subject
)
{
redirectOAuthError
(
c
,
frontendCallback
,
"subject_mismatch"
,
"userinfo subject does not match id_token"
,
""
)
redirectOAuthError
(
c
,
frontendCallback
,
"subject_mismatch"
,
"userinfo subject does not match id_token"
,
""
)
return
return
}
}
identityKey
:=
oidcIdentityKey
(
issuer
,
subject
)
identityKey
:=
oidcIdentityKey
(
issuer
,
subject
)
compatEmail
:=
strings
.
TrimSpace
(
firstNonEmpty
(
userInfoClaims
.
Email
,
idClaims
.
Email
))
compatEmail
:=
strings
.
TrimSpace
(
userInfoClaims
.
Email
)
if
compatEmail
==
""
&&
idClaims
!=
nil
{
compatEmail
=
strings
.
TrimSpace
(
idClaims
.
Email
)
}
email
:=
oidcSyntheticEmailFromIdentityKey
(
identityKey
)
email
:=
oidcSyntheticEmailFromIdentityKey
(
identityKey
)
username
:=
firstNonEmpty
(
username
:=
firstNonEmpty
(
userInfoClaims
.
Username
,
userInfoClaims
.
Username
,
idClaims
.
PreferredUsername
,
func
()
string
{
idClaims
.
Name
,
if
idClaims
!=
nil
{
return
idClaims
.
PreferredUsername
}
return
""
}(),
func
()
string
{
if
idClaims
!=
nil
{
return
idClaims
.
Name
}
return
""
}(),
oidcFallbackUsername
(
subject
),
oidcFallbackUsername
(
subject
),
)
)
identityRef
:=
service
.
PendingAuthIdentityKey
{
identityRef
:=
service
.
PendingAuthIdentityKey
{
...
@@ -344,14 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -344,14 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderSubject
:
subject
,
ProviderSubject
:
subject
,
}
}
upstreamClaims
:=
map
[
string
]
any
{
upstreamClaims
:=
map
[
string
]
any
{
"email"
:
email
,
"email"
:
email
,
"username"
:
username
,
"username"
:
username
,
"subject"
:
subject
,
"subject"
:
subject
,
"issuer"
:
issuer
,
"issuer"
:
issuer
,
"email_verified"
:
emailVerified
!=
nil
&&
*
emailVerified
,
"email_verified"
:
emailVerified
!=
nil
&&
*
emailVerified
,
"provider_fallback"
:
strings
.
TrimSpace
(
cfg
.
ProviderName
),
"provider_fallback"
:
strings
.
TrimSpace
(
cfg
.
ProviderName
),
"suggested_display_name"
:
firstNonEmpty
(
userInfoClaims
.
DisplayName
,
idClaims
.
Name
,
username
),
"suggested_display_name"
:
firstNonEmpty
(
userInfoClaims
.
DisplayName
,
func
()
string
{
"suggested_avatar_url"
:
userInfoClaims
.
AvatarURL
,
if
idClaims
!=
nil
{
return
idClaims
.
Name
}
return
""
}(),
username
),
"suggested_avatar_url"
:
userInfoClaims
.
AvatarURL
,
}
}
if
compatEmail
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
compatEmail
),
strings
.
TrimSpace
(
email
))
{
if
compatEmail
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
compatEmail
),
strings
.
TrimSpace
(
email
))
{
upstreamClaims
[
"compat_email"
]
=
compatEmail
upstreamClaims
[
"compat_email"
]
=
compatEmail
...
@@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
...
@@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
return
}
}
if
existingIdentityUser
!=
nil
{
if
existingIdentityUser
!=
nil
{
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
if
err
:=
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
if
err
:=
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Intent
:
oauthIntentLogin
,
Identity
:
identityRef
,
Identity
:
identityRef
,
TargetUserID
:
&
u
ser
.
ID
,
TargetUserID
:
&
existingIdentityU
ser
.
ID
,
ResolvedEmail
:
existingIdentityUser
.
Email
,
ResolvedEmail
:
existingIdentityUser
.
Email
,
RedirectTo
:
redirectTo
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
BrowserSessionKey
:
browserSessionKey
,
UpstreamIdentityClaims
:
upstreamClaims
,
UpstreamIdentityClaims
:
upstreamClaims
,
CompletionResponse
:
map
[
string
]
any
{
CompletionResponse
:
map
[
string
]
any
{
"access_token"
:
tokenPair
.
AccessToken
,
"redirect"
:
redirectTo
,
"refresh_token"
:
tokenPair
.
RefreshToken
,
"expires_in"
:
tokenPair
.
ExpiresIn
,
"token_type"
:
"Bearer"
,
"redirect"
:
redirectTo
,
},
},
});
err
!=
nil
{
});
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
...
@@ -537,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
...
@@ -537,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if
compatEmailUser
!=
nil
{
if
compatEmailUser
!=
nil
{
resolvedChoiceEmail
=
strings
.
TrimSpace
(
compatEmailUser
.
Email
)
resolvedChoiceEmail
=
strings
.
TrimSpace
(
compatEmailUser
.
Email
)
}
}
var
targetUserID
*
int64
if
compatEmailUser
!=
nil
&&
compatEmailUser
.
ID
>
0
{
targetUserID
=
&
compatEmailUser
.
ID
}
return
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
return
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Intent
:
oauthIntentLogin
,
Identity
:
identity
,
Identity
:
identity
,
TargetUserID
:
targetUserID
,
ResolvedEmail
:
resolvedChoiceEmail
,
ResolvedEmail
:
resolvedChoiceEmail
,
RedirectTo
:
redirectTo
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
BrowserSessionKey
:
browserSessionKey
,
...
@@ -596,6 +627,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
...
@@ -596,6 +627,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
updatedSession
,
handled
,
err
:=
h
.
legacyCompleteRegistrationSessionStatus
(
c
,
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
else
if
handled
{
c
.
JSON
(
http
.
StatusOK
,
buildPendingOAuthSessionStatusPayload
(
updatedSession
))
return
}
else
{
session
=
updatedSession
}
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -608,12 +648,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
...
@@ -608,12 +648,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
return
}
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
client
:=
h
.
entClient
()
if
err
!=
nil
{
if
client
==
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
))
return
}
if
err
:=
ensurePendingOAuthRegistrationIdentityAvailable
(
c
.
Request
.
Context
(),
client
,
session
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
return
}
}
decision
,
err
:=
h
.
upsert
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
decision
,
err
:=
h
.
ensure
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptAvatar
:
req
.
AdoptAvatar
,
AdoptAvatar
:
req
.
AdoptAvatar
,
})
})
...
@@ -621,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
...
@@ -621,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
err
:=
applyPendingOAuthAdoption
(
c
.
Request
.
Context
(),
h
.
entClient
(),
h
.
authService
,
h
.
userService
,
session
,
decision
,
&
user
.
ID
);
err
!=
nil
{
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_ADOPTION_APPLY_FAILED"
,
"failed to apply oauth profile adoption"
)
.
WithCause
(
err
))
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
if
err
:=
applyPendingOAuthAdoptionAndConsumeSession
(
c
.
Request
.
Context
(),
client
,
h
.
authService
,
h
.
userService
,
session
,
decision
,
user
.
ID
);
err
!=
nil
{
if
_
,
err
:=
pendingSvc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
...
@@ -670,7 +713,9 @@ func oidcExchangeCode(
...
@@ -670,7 +713,9 @@ func oidcExchangeCode(
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
form
.
Set
(
"code_verifier"
,
codeVerifier
)
if
strings
.
TrimSpace
(
codeVerifier
)
!=
""
{
form
.
Set
(
"code_verifier"
,
codeVerifier
)
}
r
:=
client
.
R
()
.
r
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetContext
(
ctx
)
.
...
@@ -872,9 +917,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
...
@@ -872,9 +917,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
}
}
q
.
Set
(
"state"
,
state
)
q
.
Set
(
"state"
,
state
)
q
.
Set
(
"nonce"
,
nonce
)
if
strings
.
TrimSpace
(
nonce
)
!=
""
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"nonce"
,
nonce
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
if
strings
.
TrimSpace
(
codeChallenge
)
!=
""
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
u
.
RawQuery
=
q
.
Encode
()
u
.
RawQuery
=
q
.
Encode
()
return
u
.
String
(),
nil
return
u
.
String
(),
nil
...
...
backend/internal/handler/auth_oidc_oauth_test.go
View file @
ddf80f5e
...
@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
...
@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require
.
Equal
(
t
,
int64
(
84
),
userID
)
require
.
Equal
(
t
,
int64
(
84
),
userID
)
}
}
func
TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled
(
t
*
testing
.
T
)
{
handler
:=
newOIDCOAuthTestHandler
(
t
,
false
,
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ClientID
:
"oidc-client"
,
ClientSecret
:
"oidc-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
AuthorizeURL
:
"https://issuer.example.com/oauth/authorize"
,
TokenURL
:
"https://issuer.example.com/oauth/token"
,
UserInfoURL
:
"https://issuer.example.com/oauth/userinfo"
,
Scopes
:
"openid profile email"
,
RedirectURL
:
"https://api.example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
false
,
ValidateIDToken
:
false
,
RequireEmailVerified
:
false
,
})
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/start?redirect=/dashboard"
,
nil
)
handler
.
OIDCOAuthStart
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
location
:=
recorder
.
Header
()
.
Get
(
"Location"
)
require
.
NotContains
(
t
,
location
,
"code_challenge="
)
require
.
NotContains
(
t
,
location
,
"nonce="
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oidcOAuthVerifierCookie
))
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oidcOAuthNonceCookie
))
}
func
TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation
(
t
*
testing
.
T
)
{
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
r
.
URL
.
Path
{
case
"/token"
:
require
.
NoError
(
t
,
r
.
ParseForm
())
require
.
Empty
(
t
,
r
.
PostForm
.
Get
(
"code_verifier"
))
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`
))
case
"/userinfo"
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
handler
,
client
:=
newOIDCOAuthHandlerAndClient
(
t
,
false
,
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ClientID
:
"oidc-client"
,
ClientSecret
:
"oidc-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
AuthorizeURL
:
upstream
.
URL
+
"/authorize"
,
TokenURL
:
upstream
.
URL
+
"/token"
,
UserInfoURL
:
upstream
.
URL
+
"/userinfo"
,
Scopes
:
"openid profile email"
,
RedirectURL
:
"https://api.example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
false
,
ValidateIDToken
:
false
,
RequireEmailVerified
:
false
,
})
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123"
,
nil
)
req
.
AddCookie
(
encodedCookie
(
oidcOAuthStateCookieName
,
"state-123"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthRedirectCookie
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthIntentCookieName
,
oauthIntentLogin
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-123"
))
c
.
Request
=
req
handler
.
OIDCOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Equal
(
t
,
"/auth/oidc/callback"
,
recorder
.
Header
()
.
Get
(
"Location"
))
require
.
NotNil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
}
func
TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser
(
t
*
testing
.
T
)
{
func
TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-subject-login"
,
Subject
:
"oidc-subject-login"
,
...
@@ -250,10 +333,63 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
...
@@ -250,10 +333,63 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
completion
,
ok
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
completion
,
ok
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
require
.
NotEmpty
(
t
,
completion
[
"access_token"
])
_
,
hasAccessToken
:=
completion
[
"access_token"
]
require
.
False
(
t
,
hasAccessToken
)
_
,
hasRefreshToken
:=
completion
[
"refresh_token"
]
require
.
False
(
t
,
hasRefreshToken
)
require
.
Nil
(
t
,
completion
[
"error"
])
require
.
Nil
(
t
,
completion
[
"error"
])
}
}
func
TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-disabled-subject"
,
PreferredUsername
:
"oidc_disabled"
,
DisplayName
:
"OIDC Disabled"
,
})
defer
cleanup
()
handler
,
client
:=
newOIDCOAuthHandlerAndClient
(
t
,
false
,
cfg
)
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
oidcSyntheticEmailFromIdentityKey
(
oidcIdentityKey
(
cfg
.
IssuerURL
,
"oidc-disabled-subject"
)))
.
SetUsername
(
"disabled-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
cfg
.
IssuerURL
)
.
SetProviderSubject
(
"oidc-disabled-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled"
,
nil
)
req
.
AddCookie
(
encodedCookie
(
oidcOAuthStateCookieName
,
"state-disabled"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthRedirectCookie
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthVerifierCookie
,
"verifier-disabled"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthNonceCookie
,
"nonce-oidc-disabled-subject"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthIntentCookieName
,
oauthIntentLogin
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-disabled"
))
c
.
Request
=
req
handler
.
OIDCOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
assertOAuthRedirectError
(
t
,
recorder
.
Header
()
.
Get
(
"Location"
),
"session_error"
,
"USER_NOT_ACTIVE"
)
count
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
count
)
}
func
TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser
(
t
*
testing
.
T
)
{
func
TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-subject-compat"
,
Subject
:
"oidc-subject-compat"
,
...
@@ -302,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
...
@@ -302,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
Nil
(
t
,
session
.
TargetUserID
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
require
.
Equal
(
t
,
"legacy@example.com"
,
session
.
UpstreamIdentityClaims
[
"compat_email"
])
require
.
Equal
(
t
,
"legacy@example.com"
,
session
.
UpstreamIdentityClaims
[
"compat_email"
])
...
@@ -606,6 +743,189 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
...
@@ -606,6 +743,189 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
}
func
TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-choice-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-choice-subject-1"
)
.
SetResolvedEmail
(
"oidc-choice-subject-1@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-choice-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"redirect"
:
"/dashboard"
,
"email"
:
"fresh@example.com"
,
"resolved_email"
:
"fresh@example.com"
,
"force_email_on_signup"
:
true
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-choice-browser"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"force_email_on_signup"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-no-adoption-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-subject-no-adoption"
)
.
SetResolvedEmail
(
"8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-browser-no-adoption"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
"suggested_display_name"
:
"OIDC Legacy"
,
"suggested_avatar_url"
:
"https://cdn.example/oidc-legacy.png"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-browser-no-adoption"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
NotEmpty
(
t
,
responseData
[
"refresh_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
session
.
ResolvedEmail
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"oidc_user"
,
userEntity
.
Username
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"oidc"
),
authidentity
.
ProviderKeyEQ
(
"https://issuer.example.com"
),
authidentity
.
ProviderSubjectEQ
(
"oidc-subject-no-adoption"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
func
TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
existingOwner
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingOwner
.
ID
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-conflict-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
SetResolvedEmail
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-conflict-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-conflict-browser"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
recorder
.
Code
)
payload
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
payload
[
"reason"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
type
oidcProviderFixture
struct
{
type
oidcProviderFixture
struct
{
Subject
string
Subject
string
PreferredUsername
string
PreferredUsername
string
...
...
backend/internal/handler/auth_session_revocation_test.go
0 → 100644
View file @
ddf80f5e
//go:build unit
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
29
,
Email
:
"session@example.com"
,
Username
:
"session-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
7
,
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
&
AuthHandler
{
authService
:
authService
}
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/revoke-all-sessions"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
29
})
handler
.
RevokeAllSessions
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
[]
int64
{
29
},
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
8
),
repo
.
user
.
TokenVersion
)
var
resp
struct
{
Code
int
`json:"code"`
Data
struct
{
Message
string
`json:"message"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
"All sessions have been revoked. Please log in again."
,
resp
.
Data
.
Message
)
}
backend/internal/handler/auth_wechat_oauth.go
View file @
ddf80f5e
...
@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
...
@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
return
}
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
if
err
:=
h
.
createWeChatPendingSession
(
c
,
normalizedIntent
,
providerSubject
,
existingIdentityUser
.
Email
,
redirectTo
,
browserSessionKey
,
upstreamClaims
,
nil
,
nil
,
&
existingIdentityUser
.
ID
);
err
!=
nil
{
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
if
err
:=
h
.
createWeChatPendingSession
(
c
,
normalizedIntent
,
providerSubject
,
existingIdentityUser
.
Email
,
redirectTo
,
browserSessionKey
,
upstreamClaims
,
tokenPair
,
nil
,
&
user
.
ID
);
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
return
return
}
}
...
@@ -476,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
...
@@ -476,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
}
}
func
(
h
*
AuthHandler
)
wechatPaymentResumeService
()
*
service
.
PaymentResumeService
{
func
(
h
*
AuthHandler
)
wechatPaymentResumeService
()
*
service
.
PaymentResumeService
{
var
legacyKey
[]
byte
key
,
err
:=
payment
.
ProvideEncryptionKey
(
h
.
cfg
)
key
,
err
:=
payment
.
ProvideEncryptionKey
(
h
.
cfg
)
if
err
!
=
nil
{
if
err
=
=
nil
{
return
service
.
NewPaymentResumeService
(
nil
)
legacyKey
=
[]
byte
(
key
)
}
}
return
service
.
NewPaymentResumeService
(
[]
byte
(
k
ey
)
)
return
service
.
New
LegacyAware
PaymentResumeService
(
legacyK
ey
)
}
}
type
completeWeChatOAuthRequest
struct
{
type
completeWeChatOAuthRequest
struct
{
...
@@ -530,6 +526,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
...
@@ -530,6 +526,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
updatedSession
,
handled
,
err
:=
h
.
legacyCompleteRegistrationSessionStatus
(
c
,
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
else
if
handled
{
c
.
JSON
(
http
.
StatusOK
,
buildPendingOAuthSessionStatusPayload
(
updatedSession
))
return
}
else
{
session
=
updatedSession
}
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
...
@@ -547,7 +552,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
...
@@ -547,7 +552,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
decision
,
err
:=
h
.
upsert
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
decision
,
err
:=
h
.
ensure
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptAvatar
:
req
.
AdoptAvatar
,
AdoptAvatar
:
req
.
AdoptAvatar
,
})
})
...
@@ -823,7 +828,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
...
@@ -823,7 +828,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
}
if
user
,
err
:=
singleWeChatIdentityUser
(
records
);
err
!=
nil
||
user
!=
nil
{
if
user
,
err
:=
singleWeChatIdentityUser
(
records
);
err
!=
nil
||
user
!=
nil
{
return
user
,
err
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
}
}
}
...
@@ -847,7 +855,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
...
@@ -847,7 +855,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED"
,
"failed to inspect auth identity channel ownership"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED"
,
"failed to inspect auth identity channel ownership"
)
.
WithCause
(
err
)
}
}
if
user
,
err
:=
singleWeChatChannelUser
(
records
);
err
!=
nil
||
user
!=
nil
{
if
user
,
err
:=
singleWeChatChannelUser
(
records
);
err
!=
nil
||
user
!=
nil
{
return
user
,
err
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
}
}
}
...
@@ -866,7 +877,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
...
@@ -866,7 +877,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
}
return
singleWeChatIdentityUser
(
records
)
user
,
err
:=
singleWeChatIdentityUser
(
records
)
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
}
func
wechatCompatibleProviderKeys
(
providerKey
string
)
[]
string
{
func
wechatCompatibleProviderKeys
(
providerKey
string
)
[]
string
{
...
...
backend/internal/handler/auth_wechat_oauth_test.go
View file @
ddf80f5e
...
@@ -213,6 +213,151 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
...
@@ -213,6 +213,151 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
require
.
Equal
(
t
,
"third_party_signup"
,
completion
[
"choice_reason"
])
require
.
Equal
(
t
,
"third_party_signup"
,
completion
[
"choice_reason"
])
}
}
func
TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
wechatOAuthUserInfoURL
=
originalUserInfoURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
{
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`
))
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/userinfo"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
wechatOAuthUserInfoURL
=
upstream
.
URL
+
"/sns/userinfo"
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"open"
,
"wx-open-app"
,
"wx-open-secret"
,
"https://app.example.com/auth/wechat/callback"
))
defer
client
.
Close
()
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
wechatSyntheticEmail
(
"union-456"
))
.
SetUsername
(
"wechat-existing-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"union-456"
)
.
SetMetadata
(
map
[
string
]
any
{
"username"
:
"wechat-existing-user"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatOAuthStateCookieName
,
"state-123"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthRedirectCookieName
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthModeCookieName
,
"open"
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-123"
))
c
.
Request
=
req
handler
.
WeChatOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Equal
(
t
,
"https://app.example.com/auth/wechat/callback"
,
recorder
.
Header
()
.
Get
(
"Location"
))
sessionCookie
:=
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
)
require
.
NotNil
(
t
,
sessionCookie
)
session
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
SessionTokenEQ
(
decodeCookieValueForTest
(
t
,
sessionCookie
.
Value
)))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
completion
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
_
,
hasAccessToken
:=
completion
[
"access_token"
]
require
.
False
(
t
,
hasAccessToken
)
_
,
hasRefreshToken
:=
completion
[
"refresh_token"
]
require
.
False
(
t
,
hasRefreshToken
)
}
func
TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
wechatOAuthUserInfoURL
=
originalUserInfoURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
{
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`
))
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/userinfo"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
wechatOAuthUserInfoURL
=
upstream
.
URL
+
"/sns/userinfo"
handler
,
client
:=
newWeChatOAuthTestHandler
(
t
,
false
)
defer
client
.
Close
()
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
wechatSyntheticEmail
(
"union-disabled"
))
.
SetUsername
(
"disabled-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"union-disabled"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatOAuthStateCookieName
,
"state-disabled"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthRedirectCookieName
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthModeCookieName
,
"open"
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-disabled"
))
c
.
Request
=
req
handler
.
WeChatOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
assertOAuthRedirectError
(
t
,
recorder
.
Header
()
.
Get
(
"Location"
),
"session_error"
,
"USER_NOT_ACTIVE"
)
count
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
count
)
}
func
TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken
(
t
*
testing
.
T
)
{
func
TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
t
.
Cleanup
(
func
()
{
t
.
Cleanup
(
func
()
{
...
@@ -233,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
...
@@ -233,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"mp"
,
"wx-mp-app"
,
"wx-mp-secret"
,
"/auth/wechat/callback"
))
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"mp"
,
"wx-mp-app"
,
"wx-mp-secret"
,
"/auth/wechat/callback"
))
defer
client
.
Close
()
defer
client
.
Close
()
handler
.
cfg
.
Totp
.
EncryptionKey
=
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
handler
.
cfg
.
Totp
.
EncryptionKey
=
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
handler
.
cfg
.
Totp
.
EncryptionKeyConfigured
=
true
recorder
:=
httptest
.
NewRecorder
()
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
...
@@ -270,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
...
@@ -270,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
require
.
Equal
(
t
,
"/purchase?from=wechat"
,
claims
.
RedirectTo
)
require
.
Equal
(
t
,
"/purchase?from=wechat"
,
claims
.
RedirectTo
)
}
}
func
TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`
))
return
}
http
.
NotFound
(
w
,
r
)
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"mp"
,
"wx-mp-app"
,
"wx-mp-secret"
,
"/auth/wechat/callback"
))
defer
client
.
Close
()
legacyKeyHex
:=
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
explicitSigningKey
:=
"explicit-payment-resume-signing-key"
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
explicitSigningKey
)
handler
.
cfg
.
Totp
.
EncryptionKey
=
legacyKeyHex
handler
.
cfg
.
Totp
.
EncryptionKeyConfigured
=
true
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthStateName
,
"state-mixed"
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthRedirect
,
"/purchase?from=wechat"
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthContextName
,
`{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthScope
,
"snsapi_base"
))
c
.
Request
=
req
handler
.
WeChatPaymentOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
location
:=
recorder
.
Header
()
.
Get
(
"Location"
)
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
fragment
,
err
:=
url
.
ParseQuery
(
parsed
.
Fragment
)
require
.
NoError
(
t
,
err
)
token
:=
fragment
.
Get
(
"wechat_resume_token"
)
require
.
NotEmpty
(
t
,
token
)
claims
,
err
:=
service
.
NewPaymentResumeService
([]
byte
(
explicitSigningKey
))
.
ParseWeChatPaymentResumeToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"openid-mixed-key"
,
claims
.
OpenID
)
require
.
Equal
(
t
,
payment
.
TypeWxpay
,
claims
.
PaymentType
)
require
.
Equal
(
t
,
"18.8"
,
claims
.
Amount
)
require
.
Equal
(
t
,
payment
.
OrderTypeSubscription
,
claims
.
OrderType
)
require
.
EqualValues
(
t
,
9
,
claims
.
PlanID
)
require
.
Equal
(
t
,
"/purchase?from=wechat"
,
claims
.
RedirectTo
)
_
,
err
=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
.
ParseWeChatPaymentResumeToken
(
token
)
require
.
Error
(
t
,
err
)
}
func
TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels
(
t
*
testing
.
T
)
{
func
TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels
(
t
*
testing
.
T
)
{
testCases
:=
[]
struct
{
testCases
:=
[]
struct
{
name
string
name
string
...
@@ -620,7 +827,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
...
@@ -620,7 +827,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require
.
Zero
(
t
,
count
)
require
.
Zero
(
t
,
count
)
}
}
func
TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession
(
t
*
testing
.
T
)
{
func
TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession
ReturnsPendingSession
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
t
.
Cleanup
(
func
()
{
...
@@ -693,27 +900,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
...
@@ -693,27 +900,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require
.
Equal
(
t
,
http
.
StatusOK
,
completeRecorder
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
completeRecorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
completeRecorder
)
responseData
:=
decodeJSONBody
(
t
,
completeRecorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"adoption_required"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
consumed
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"wechat-union-456@wechat-connect.invalid"
))
.
Where
(
pendingauthsession
.
IDEQ
(
pendingSession
.
ID
))
.
Only
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equa
l
(
t
,
"WeChat Display"
,
userEntity
.
Username
)
require
.
Ni
l
(
t
,
consumed
.
ConsumedAt
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
identityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
"wechat-main"
),
authidentity
.
ProviderKeyEQ
(
"wechat-main"
),
authidentity
.
ProviderSubjectEQ
(
"union-456"
),
authidentity
.
ProviderSubjectEQ
(
"union-456"
),
)
.
)
.
Only
(
ctx
)
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
require
.
Zero
(
t
,
identityCount
)
require
.
Equal
(
t
,
"WeChat Display"
,
identity
.
Metadata
[
"display_name"
])
require
.
Equal
(
t
,
"https://cdn.example/wechat.png"
,
identity
.
Metadata
[
"avatar_url"
])
channel
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
channel
Count
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ProviderKeyEQ
(
"wechat-main"
),
authidentitychannel
.
ProviderKeyEQ
(
"wechat-main"
),
...
@@ -721,25 +933,82 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
...
@@ -721,25 +933,82 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open-app"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open-app"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-123"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-123"
),
)
.
)
.
Only
(
ctx
)
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
identity
.
ID
,
channel
.
IdentityID
)
require
.
Zero
(
t
,
channelCount
)
require
.
Equal
(
t
,
"union-456"
,
channel
.
Metadata
[
"unionid"
])
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
decision
Count
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
pendingSession
.
ID
))
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
pendingSession
.
ID
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
decisionCount
)
}
func
TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"wechat-complete-no-adoption-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"wechat-subject-no-adoption"
)
.
SetResolvedEmail
(
"wechat-subject-no-adoption@wechat-connect.invalid"
)
.
SetBrowserSessionKey
(
"wechat-browser-no-adoption"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"wechat_user"
,
"suggested_display_name"
:
"WeChat Legacy"
,
"suggested_avatar_url"
:
"https://cdn.example/wechat-legacy.png"
,
"mode"
:
"open"
,
"channel"
:
"open"
,
"channel_app_id"
:
"wx-open-app"
,
"channel_subject"
:
"openid-legacy"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
completeCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
completeReq
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/wechat/complete-registration"
,
body
)
completeReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
completeReq
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
completeReq
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"wechat-browser-no-adoption"
)})
completeCtx
.
Request
=
completeReq
handler
.
CompleteWeChatOAuthRegistration
(
completeCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
NotEmpty
(
t
,
responseData
[
"refresh_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
session
.
ResolvedEmail
))
.
Only
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
"wechat_user"
,
userEntity
.
Username
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
True
(
t
,
decision
.
AdoptDisplayName
)
require
.
True
(
t
,
decision
.
AdoptAvatar
)
consumed
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
pendingauthsession
.
IDEQ
(
pendingSession
.
ID
))
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
wechatOAuthProviderKey
),
authidentity
.
ProviderSubjectEQ
(
"wechat-subject-no-adoption"
),
)
.
Only
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
}
func
TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity
(
t
*
testing
.
T
)
{
func
TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity
(
t
*
testing
.
T
)
{
...
@@ -901,6 +1170,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
...
@@ -901,6 +1170,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
}
func
TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
(
t
*
testing
.
T
)
{
handler
,
client
:=
newWeChatOAuthTestHandler
(
t
,
false
)
defer
client
.
Close
()
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"wechat-complete-choice-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"wechat-choice-subject-1"
)
.
SetResolvedEmail
(
"wechat-choice-subject-1@wechat-connect.invalid"
)
.
SetBrowserSessionKey
(
"wechat-choice-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"wechat_user"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"redirect"
:
"/dashboard"
,
"email"
:
"fresh@example.com"
,
"resolved_email"
:
"fresh@example.com"
,
"force_email_on_signup"
:
true
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
completeCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/wechat/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"wechat-choice-browser"
)})
completeCtx
.
Request
=
req
handler
.
CompleteWeChatOAuthRegistration
(
completeCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"force_email_on_signup"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity
(
t
*
testing
.
T
)
{
func
TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
...
@@ -1083,18 +1408,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
...
@@ -1083,18 +1408,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
},
client
},
client
}
}
func
assertOAuthRedirectError
(
t
*
testing
.
T
,
location
string
,
errorCode
string
,
errorMessage
string
)
{
t
.
Helper
()
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
fragment
,
err
:=
url
.
ParseQuery
(
parsed
.
Fragment
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
errorCode
,
fragment
.
Get
(
"error"
))
require
.
Equal
(
t
,
errorMessage
,
fragment
.
Get
(
"error_message"
))
}
type
wechatOAuthSettingRepoStub
struct
{
type
wechatOAuthSettingRepoStub
struct
{
values
map
[
string
]
string
values
map
[
string
]
string
}
}
...
...
backend/internal/handler/payment_handler.go
View file @
ddf80f5e
...
@@ -2,9 +2,9 @@ package handler
...
@@ -2,9 +2,9 @@ package handler
import
(
import
(
"fmt"
"fmt"
"net/http"
"strconv"
"strconv"
"strings"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment"
...
@@ -454,29 +454,65 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
...
@@ -454,29 +454,65 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
// No user details are exposed — only payment status information.
type
PublicOrderResult
struct
{
type
PublicOrderResult
struct
{
ID
int64
`json:"id"`
ID
int64
`json:"id"`
OutTradeNo
string
`json:"out_trade_no"`
OutTradeNo
string
`json:"out_trade_no"`
Amount
float64
`json:"amount"`
Amount
float64
`json:"amount"`
PayAmount
float64
`json:"pay_amount"`
PayAmount
float64
`json:"pay_amount"`
PaymentType
string
`json:"payment_type"`
FeeRate
float64
`json:"fee_rate"`
OrderType
string
`json:"order_type"`
PaymentType
string
`json:"payment_type"`
Status
string
`json:"status"`
OrderType
string
`json:"order_type"`
Status
string
`json:"status"`
CreatedAt
time
.
Time
`json:"created_at"`
ExpiresAt
time
.
Time
`json:"expires_at"`
PaidAt
*
time
.
Time
`json:"paid_at,omitempty"`
CompletedAt
*
time
.
Time
`json:"completed_at,omitempty"`
RefundAmount
float64
`json:"refund_amount"`
RefundReason
*
string
`json:"refund_reason,omitempty"`
RefundRequestedAt
*
time
.
Time
`json:"refund_requested_at,omitempty"`
RefundRequestedBy
*
string
`json:"refund_requested_by,omitempty"`
RefundRequestReason
*
string
`json:"refund_request_reason,omitempty"`
PlanID
*
int64
`json:"plan_id,omitempty"`
}
}
var
errPaymentPublicOrderVerifyRemoved
=
infraerrors
.
New
(
func
buildPublicOrderResult
(
order
*
dbent
.
PaymentOrder
)
PublicOrderResult
{
http
.
StatusGone
,
return
PublicOrderResult
{
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED"
,
ID
:
order
.
ID
,
"public payment order verification by out_trade_no has been removed; use resume_token recovery instead"
,
OutTradeNo
:
order
.
OutTradeNo
,
)
.
WithMetadata
(
map
[
string
]
string
{
Amount
:
order
.
Amount
,
"replacement_endpoint"
:
"/api/v1/payment/public/orders/resolve"
,
PayAmount
:
order
.
PayAmount
,
"replacement_field"
:
"resume_token"
,
FeeRate
:
order
.
FeeRate
,
})
PaymentType
:
order
.
PaymentType
,
OrderType
:
order
.
OrderType
,
// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
Status
:
order
.
Status
,
// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
CreatedAt
:
order
.
CreatedAt
,
ExpiresAt
:
order
.
ExpiresAt
,
PaidAt
:
order
.
PaidAt
,
CompletedAt
:
order
.
CompletedAt
,
RefundAmount
:
order
.
RefundAmount
,
RefundReason
:
order
.
RefundReason
,
RefundRequestedAt
:
order
.
RefundRequestedAt
,
RefundRequestedBy
:
order
.
RefundRequestedBy
,
RefundRequestReason
:
order
.
RefundRequestReason
,
PlanID
:
order
.
PlanID
,
}
}
// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
// POST /api/v1/payment/public/orders/verify
func
(
h
*
PaymentHandler
)
VerifyOrderPublic
(
c
*
gin
.
Context
)
{
func
(
h
*
PaymentHandler
)
VerifyOrderPublic
(
c
*
gin
.
Context
)
{
response
.
ErrorFrom
(
c
,
errPaymentPublicOrderVerifyRemoved
)
var
req
VerifyOrderRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
order
,
err
:=
h
.
paymentService
.
VerifyOrderPublic
(
c
.
Request
.
Context
(),
req
.
OutTradeNo
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
buildPublicOrderResult
(
order
))
}
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
...
@@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
...
@@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
response
.
Success
(
c
,
PublicOrderResult
{
response
.
Success
(
c
,
buildPublicOrderResult
(
order
))
ID
:
order
.
ID
,
OutTradeNo
:
order
.
OutTradeNo
,
Amount
:
order
.
Amount
,
PayAmount
:
order
.
PayAmount
,
PaymentType
:
order
.
PaymentType
,
OrderType
:
order
.
OrderType
,
Status
:
order
.
Status
,
})
}
}
// requireAuth extracts the authenticated subject from the context.
// requireAuth extracts the authenticated subject from the context.
...
...
backend/internal/handler/payment_handler_resume_test.go
View file @
ddf80f5e
...
@@ -4,16 +4,17 @@ package handler
...
@@ -4,16 +4,17 @@ package handler
import
(
import
(
"bytes"
"bytes"
"context"
"database/sql"
"database/sql"
"encoding/json"
"encoding/json"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
"testing"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
...
@@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
}
}
}
}
func
TestVerifyOrderPublicReturns
Gon
e
(
t
*
testing
.
T
)
{
func
TestVerifyOrderPublicReturns
LegacyOrderStat
e
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Parallel
()
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
@@ -90,6 +91,32 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
...
@@ -90,6 +91,32 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-verify@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-verify-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
88
)
.
SetPayAmount
(
90.64
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-VERIFY"
)
.
SetOutTradeNo
(
"legacy-order-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-verify"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPending
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
...
@@ -104,11 +131,238 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
...
@@ -104,11 +131,238 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
h
.
VerifyOrderPublic
(
ctx
)
h
.
VerifyOrderPublic
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusGone
,
recorder
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
struct
{
ID
int64
`json:"id"`
OutTradeNo
string
`json:"out_trade_no"`
Amount
float64
`json:"amount"`
PayAmount
float64
`json:"pay_amount"`
FeeRate
float64
`json:"fee_rate"`
PaymentType
string
`json:"payment_type"`
OrderType
string
`json:"order_type"`
Status
string
`json:"status"`
RefundAmount
float64
`json:"refund_amount"`
CreatedAt
string
`json:"created_at"`
ExpiresAt
string
`json:"expires_at"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
order
.
ID
,
resp
.
Data
.
ID
)
require
.
Equal
(
t
,
"legacy-order-no"
,
resp
.
Data
.
OutTradeNo
)
require
.
Equal
(
t
,
90.64
,
resp
.
Data
.
PayAmount
)
require
.
Equal
(
t
,
0.03
,
resp
.
Data
.
FeeRate
)
require
.
Equal
(
t
,
payment
.
TypeAlipay
,
resp
.
Data
.
PaymentType
)
require
.
Equal
(
t
,
payment
.
OrderTypeBalance
,
resp
.
Data
.
OrderType
)
require
.
Equal
(
t
,
service
.
OrderStatusPending
,
resp
.
Data
.
Status
)
require
.
Equal
(
t
,
0.0
,
resp
.
Data
.
RefundAmount
)
require
.
NotEmpty
(
t
,
resp
.
Data
.
CreatedAt
)
require
.
NotEmpty
(
t
,
resp
.
Data
.
ExpiresAt
)
}
func
TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"0123456789abcdef0123456789abcdef"
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_resolve?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-resolve@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-resolve-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
100
)
.
SetPayAmount
(
103
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-RESOLVE"
)
.
SetOutTradeNo
(
"resolve-order-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-resolve"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPaid
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetPaidAt
(
time
.
Now
())
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
resumeSvc
:=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
token
,
err
:=
resumeSvc
.
CreateToken
(
service
.
ResumeTokenClaims
{
OrderID
:
order
.
ID
,
UserID
:
user
.
ID
,
PaymentType
:
payment
.
TypeAlipay
,
CanonicalReturnURL
:
"https://app.example.com/payment/result"
,
})
require
.
NoError
(
t
,
err
)
configSvc
:=
service
.
NewPaymentConfigService
(
client
,
nil
,
[]
byte
(
"0123456789abcdef0123456789abcdef"
))
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
configSvc
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/resolve"
,
bytes
.
NewBufferString
(
`{"resume_token":"`
+
token
+
`"}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
ResolveOrderPublicByResumeToken
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
map
[
string
]
any
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
float64
(
order
.
ID
),
resp
.
Data
[
"id"
])
require
.
Equal
(
t
,
"resolve-order-no"
,
resp
.
Data
[
"out_trade_no"
])
require
.
Equal
(
t
,
100.0
,
resp
.
Data
[
"amount"
])
require
.
Equal
(
t
,
103.0
,
resp
.
Data
[
"pay_amount"
])
require
.
Equal
(
t
,
0.03
,
resp
.
Data
[
"fee_rate"
])
require
.
Equal
(
t
,
payment
.
TypeAlipay
,
resp
.
Data
[
"payment_type"
])
require
.
Equal
(
t
,
payment
.
OrderTypeBalance
,
resp
.
Data
[
"order_type"
])
require
.
Equal
(
t
,
service
.
OrderStatusPaid
,
resp
.
Data
[
"status"
])
require
.
Contains
(
t
,
resp
.
Data
,
"created_at"
)
require
.
Contains
(
t
,
resp
.
Data
,
"expires_at"
)
require
.
Contains
(
t
,
resp
.
Data
,
"refund_amount"
)
}
var
resp
response
.
Response
func
TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"0123456789abcdef0123456789abcdef"
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-resolve-mismatch@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-resolve-mismatch-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
100
)
.
SetPayAmount
(
103
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-RESOLVE-MISMATCH"
)
.
SetOutTradeNo
(
"resolve-order-mismatch-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-resolve-mismatch"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPaid
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetPaidAt
(
time
.
Now
())
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
resumeSvc
:=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
token
,
err
:=
resumeSvc
.
CreateToken
(
service
.
ResumeTokenClaims
{
OrderID
:
order
.
ID
,
UserID
:
user
.
ID
+
999
,
PaymentType
:
payment
.
TypeAlipay
,
CanonicalReturnURL
:
"https://app.example.com/payment/result"
,
})
require
.
NoError
(
t
,
err
)
configSvc
:=
service
.
NewPaymentConfigService
(
client
,
nil
,
[]
byte
(
"0123456789abcdef0123456789abcdef"
))
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
configSvc
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/resolve"
,
bytes
.
NewBufferString
(
`{"resume_token":"`
+
token
+
`"}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
ResolveOrderPublicByResumeToken
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Reason
string
`json:"reason"`
Message
string
`json:"message"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Code
)
require
.
Equal
(
t
,
"INVALID_RESUME_TOKEN"
,
resp
.
Reason
)
}
func
TestVerifyOrderPublicRejectsBlankOutTradeNo
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_verify_blank?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/verify"
,
bytes
.
NewBufferString
(
`{"out_trade_no":" "}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
VerifyOrderPublic
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Reason
string
`json:"reason"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusGone
,
resp
.
Code
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Code
)
require
.
Equal
(
t
,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED"
,
resp
.
Reason
)
require
.
Equal
(
t
,
"INVALID_OUT_TRADE_NO"
,
resp
.
Reason
)
require
.
Contains
(
t
,
resp
.
Message
,
"removed"
)
}
}
backend/internal/handler/user_handler.go
View file @
ddf80f5e
...
@@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
...
@@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
return
return
}
}
updatedUser
,
err
:=
h
.
userService
.
UnbindUserAuthProvider
(
updatedUser
,
unbound
,
err
:=
h
.
userService
.
UnbindUserAuthProvider
WithResult
(
c
.
Request
.
Context
(),
c
.
Request
.
Context
(),
subject
.
UserID
,
subject
.
UserID
,
c
.
Param
(
"provider"
),
c
.
Param
(
"provider"
),
...
@@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
...
@@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
return
}
}
if
unbound
&&
h
.
authService
!=
nil
{
if
err
:=
h
.
authService
.
RevokeAllUserTokens
(
c
.
Request
.
Context
(),
subject
.
UserID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
}
profileResp
,
err
:=
h
.
buildUserProfileResponse
(
c
.
Request
.
Context
(),
subject
.
UserID
,
updatedUser
)
profileResp
,
err
:=
h
.
buildUserProfileResponse
(
c
.
Request
.
Context
(),
subject
.
UserID
,
updatedUser
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
...
@@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
thirdParty
:=
thirdPartyIdentityProviders
(
identities
)
thirdParty
:=
thirdPartyIdentityProviders
(
identities
)
var
avatarSource
*
userProfileSourceContext
var
avatarSource
*
userProfileSourceContext
if
strings
.
TrimSpace
(
user
.
AvatarURL
)
!=
""
&&
len
(
thirdParty
)
==
1
{
avatarValue
:=
strings
.
TrimSpace
(
user
.
AvatarURL
)
avatarSource
=
buildUserProfileSourceContext
(
thirdParty
[
0
]
.
Provider
)
for
_
,
summary
:=
range
thirdParty
{
if
avatarValue
!=
""
&&
avatarValue
==
strings
.
TrimSpace
(
summary
.
AvatarURL
)
{
avatarSource
=
buildUserProfileSourceContext
(
summary
.
Provider
)
break
}
}
}
usernameValue
:=
strings
.
TrimSpace
(
user
.
Username
)
usernameValue
:=
strings
.
TrimSpace
(
user
.
Username
)
...
@@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
...
@@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
break
break
}
}
}
}
if
usernameSource
==
nil
&&
usernameValue
!=
""
&&
len
(
thirdParty
)
==
1
{
usernameSource
=
buildUserProfileSourceContext
(
thirdParty
[
0
]
.
Provider
)
}
profileSources
:=
map
[
string
]
*
userProfileSourceContext
{}
profileSources
:=
map
[
string
]
*
userProfileSourceContext
{}
if
avatarSource
!=
nil
{
if
avatarSource
!=
nil
{
...
...
backend/internal/handler/user_handler_test.go
View file @
ddf80f5e
...
@@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
...
@@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require
.
Equal
(
t
,
"https://issuer.example.com"
,
resp
.
Data
.
Identities
.
OIDC
.
ProviderKey
)
require
.
Equal
(
t
,
"https://issuer.example.com"
,
resp
.
Data
.
Identities
.
OIDC
.
ProviderKey
)
require
.
False
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
Bound
)
require
.
False
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
Bound
)
require
.
True
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
CanBind
)
require
.
True
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
CanBind
)
require
.
Contains
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
BindStartPath
,
"/api/v1/auth/oauth/wechat/start"
)
require
.
Contains
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
BindStartPath
,
"/api/v1/auth/oauth/wechat/
bind/
start"
)
}
}
func
TestUserHandlerGetProfileReturnsLegacyCompatibilityFields
(
t
*
testing
.
T
)
{
func
TestUserHandlerGetProfileReturnsLegacyCompatibilityFields
(
t
*
testing
.
T
)
{
...
@@ -270,18 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
...
@@ -270,18 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
AvatarURL
:
"https://cdn.example.com/linuxdo.png"
,
AvatarURL
:
"https://cdn.example.com/linuxdo.png"
,
AvatarSource
:
"remote_url"
,
AvatarSource
:
"remote_url"
,
},
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
{
ProviderType
:
"linuxdo"
,
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-21"
,
ProviderSubject
:
"linuxdo-subject-21"
,
VerifiedAt
:
&
verifiedAt
,
VerifiedAt
:
&
verifiedAt
,
Metadata
:
map
[
string
]
any
{
Metadata
:
map
[
string
]
any
{
"username"
:
"linuxdo-handle"
,
"username"
:
"linuxdo-handle"
,
"avatar_url"
:
"https://cdn.example.com/linuxdo.png"
,
},
},
},
},
},
},
}
}
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
nil
,
nil
,
nil
)
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
nil
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
recorder
:=
httptest
.
NewRecorder
()
...
@@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
...
@@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require
.
Equal
(
t
,
"linuxdo"
,
usernameSource
[
"source"
])
require
.
Equal
(
t
,
"linuxdo"
,
usernameSource
[
"source"
])
}
}
func
TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
22
,
Email
:
"edited-profile@example.com"
,
Username
:
"custom-name"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
AvatarURL
:
"https://cdn.example.com/custom.png"
,
AvatarSource
:
"remote_url"
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-22"
,
Metadata
:
map
[
string
]
any
{
"username"
:
"linuxdo-handle"
,
"avatar_url"
:
"https://cdn.example.com/linuxdo.png"
,
},
},
},
}
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
nil
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/user/profile"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
22
})
handler
.
GetProfile
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
map
[
string
]
any
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
NotContains
(
t
,
resp
.
Data
,
"avatar_source"
)
require
.
NotContains
(
t
,
resp
.
Data
,
"username_source"
)
require
.
NotContains
(
t
,
resp
.
Data
,
"profile_sources"
)
}
type
userHandlerEmailCacheStub
struct
{
type
userHandlerEmailCacheStub
struct
{
data
*
service
.
VerificationCodeData
data
*
service
.
VerificationCodeData
}
}
type
userHandlerRefreshTokenCacheStub
struct
{
revokedUserIDs
[]
int64
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
StoreRefreshToken
(
context
.
Context
,
string
,
*
service
.
RefreshTokenData
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetRefreshToken
(
context
.
Context
,
string
)
(
*
service
.
RefreshTokenData
,
error
)
{
return
nil
,
service
.
ErrRefreshTokenNotFound
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteRefreshToken
(
context
.
Context
,
string
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteUserRefreshTokens
(
_
context
.
Context
,
userID
int64
)
error
{
s
.
revokedUserIDs
=
append
(
s
.
revokedUserIDs
,
userID
)
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteTokenFamily
(
context
.
Context
,
string
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
AddToUserTokenSet
(
context
.
Context
,
int64
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
AddToFamilyTokenSet
(
context
.
Context
,
string
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetUserTokenHashes
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetFamilyTokenHashes
(
context
.
Context
,
string
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
IsTokenInFamily
(
context
.
Context
,
string
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
s
*
userHandlerEmailCacheStub
)
GetVerificationCode
(
context
.
Context
,
string
)
(
*
service
.
VerificationCodeData
,
error
)
{
func
(
s
*
userHandlerEmailCacheStub
)
GetVerificationCode
(
context
.
Context
,
string
)
(
*
service
.
VerificationCodeData
,
error
)
{
return
s
.
data
,
nil
return
s
.
data
,
nil
}
}
...
@@ -495,6 +588,98 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
...
@@ -495,6 +588,98 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
require
.
Equal
(
t
,
false
,
linuxdoBinding
[
"bound"
])
require
.
Equal
(
t
,
false
,
linuxdoBinding
[
"bound"
])
}
}
func
TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
23
,
Email
:
"identity@example.com"
,
Username
:
"identity-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"email"
,
ProviderKey
:
"email"
,
ProviderSubject
:
"identity@example.com"
,
},
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-23"
,
},
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
authService
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodDelete
,
"/api/v1/user/account-bindings/linuxdo"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
23
})
c
.
Params
=
gin
.
Params
{{
Key
:
"provider"
,
Value
:
"linuxdo"
}}
handler
.
UnbindIdentity
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
[]
int64
{
23
},
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
5
),
repo
.
user
.
TokenVersion
)
}
func
TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
24
,
Email
:
"identity@example.com"
,
Username
:
"identity-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"email"
,
ProviderKey
:
"email"
,
ProviderSubject
:
"identity@example.com"
,
},
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
authService
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodDelete
,
"/api/v1/user/account-bindings/linuxdo"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
24
})
c
.
Params
=
gin
.
Params
{{
Key
:
"provider"
,
Value
:
"linuxdo"
}}
handler
.
UnbindIdentity
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Empty
(
t
,
repo
.
unbound
)
require
.
Empty
(
t
,
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
4
),
repo
.
user
.
TokenVersion
)
}
func
TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail
(
t
*
testing
.
T
)
{
func
TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
@@ -587,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
...
@@ -587,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
require
.
Equal
(
t
,
"wechat"
,
resp
.
Data
.
Provider
)
require
.
Equal
(
t
,
"wechat"
,
resp
.
Data
.
Provider
)
require
.
Equal
(
t
,
"GET"
,
resp
.
Data
.
Method
)
require
.
Equal
(
t
,
"GET"
,
resp
.
Data
.
Method
)
require
.
True
(
t
,
resp
.
Data
.
UseBrowserRedirect
)
require
.
True
(
t
,
resp
.
Data
.
UseBrowserRedirect
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"/api/v1/auth/oauth/wechat/start"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"/api/v1/auth/oauth/wechat/
bind/
start"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"intent=bind_current_user"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"intent=bind_current_user"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"redirect=%2Fsettings%2Fprofile"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"redirect=%2Fsettings%2Fprofile"
)
}
}
backend/internal/payment/provider/wxpay.go
View file @
ddf80f5e
...
@@ -60,11 +60,6 @@ const (
...
@@ -60,11 +60,6 @@ const (
wxpayEventTransactionSuccess
=
"TRANSACTION.SUCCESS"
wxpayEventTransactionSuccess
=
"TRANSACTION.SUCCESS"
)
)
// WeChat Pay error codes.
const
(
wxpayErrNoAuth
=
"NO_AUTH"
)
var
(
var
(
wxpayNativePrepay
=
func
(
ctx
context
.
Context
,
svc
native
.
NativeApiService
,
req
native
.
PrepayRequest
)
(
*
native
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
wxpayNativePrepay
=
func
(
ctx
context
.
Context
,
svc
native
.
NativeApiService
,
req
native
.
PrepayRequest
)
(
*
native
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
return
svc
.
Prepay
(
ctx
,
req
)
return
svc
.
Prepay
(
ctx
,
req
)
...
@@ -200,14 +195,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
...
@@ -200,14 +195,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
case
wxpayModeJSAPI
:
case
wxpayModeJSAPI
:
return
w
.
prepayJSAPI
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
return
w
.
prepayJSAPI
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
case
wxpayModeH5
:
case
wxpayModeH5
:
resp
,
err
:=
w
.
prepayH5
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
return
w
.
prepayH5
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
if
err
==
nil
{
return
resp
,
nil
}
if
strings
.
Contains
(
err
.
Error
(),
wxpayErrNoAuth
)
{
return
nil
,
fmt
.
Errorf
(
"wxpay h5 payments are not authorized for this merchant: %w"
,
err
)
}
return
nil
,
err
case
wxpayModeNative
:
case
wxpayModeNative
:
return
w
.
prepayNative
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
return
w
.
prepayNative
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
default
:
default
:
...
...
backend/internal/payment/provider/wxpay_test.go
View file @
ddf80f5e
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"crypto/rsa"
"crypto/rsa"
"crypto/x509"
"crypto/x509"
"encoding/pem"
"encoding/pem"
"errors"
"net/url"
"net/url"
"strings"
"strings"
"testing"
"testing"
...
@@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
...
@@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
t
.
Fatalf
(
"pay_url = %q, want redirect_url query appended"
,
resp
.
PayURL
)
t
.
Fatalf
(
"pay_url = %q, want redirect_url query appended"
,
resp
.
PayURL
)
}
}
}
}
func
TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback
(
t
*
testing
.
T
)
{
origJSAPIPrepay
:=
wxpayJSAPIPrepayWithRequestPayment
origNativePrepay
:=
wxpayNativePrepay
origH5Prepay
:=
wxpayH5Prepay
t
.
Cleanup
(
func
()
{
wxpayJSAPIPrepayWithRequestPayment
=
origJSAPIPrepay
wxpayNativePrepay
=
origNativePrepay
wxpayH5Prepay
=
origH5Prepay
})
jsapiCalls
:=
0
nativeCalls
:=
0
h5Calls
:=
0
wxpayJSAPIPrepayWithRequestPayment
=
func
(
ctx
context
.
Context
,
svc
jsapi
.
JsapiApiService
,
req
jsapi
.
PrepayRequest
)
(
*
jsapi
.
PrepayWithRequestPaymentResponse
,
*
core
.
APIResult
,
error
)
{
jsapiCalls
++
return
&
jsapi
.
PrepayWithRequestPaymentResponse
{},
nil
,
nil
}
wxpayH5Prepay
=
func
(
ctx
context
.
Context
,
svc
h5
.
H5ApiService
,
req
h5
.
PrepayRequest
)
(
*
h5
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
h5Calls
++
return
nil
,
nil
,
errors
.
New
(
"NO_AUTH"
)
}
wxpayNativePrepay
=
func
(
ctx
context
.
Context
,
svc
native
.
NativeApiService
,
req
native
.
PrepayRequest
)
(
*
native
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
nativeCalls
++
return
&
native
.
PrepayResponse
{
CodeUrl
:
core
.
String
(
"weixin://wxpay/bizpayurl?pr=fallback-native"
),
},
nil
,
nil
}
provider
:=
&
Wxpay
{
config
:
map
[
string
]
string
{
"appId"
:
"wx123"
,
"mchId"
:
"mch123"
,
},
coreClient
:
&
core
.
Client
{},
}
resp
,
err
:=
provider
.
CreatePayment
(
context
.
Background
(),
payment
.
CreatePaymentRequest
{
OrderID
:
"sub2_100"
,
Amount
:
"66.88"
,
PaymentType
:
payment
.
TypeWxpay
,
Subject
:
"Balance Recharge"
,
NotifyURL
:
"https://merchant.example/payment/notify"
,
ClientIP
:
"203.0.113.10"
,
IsMobile
:
true
,
})
if
err
==
nil
{
t
.
Fatal
(
"expected no-auth error, got nil"
)
}
if
jsapiCalls
!=
0
{
t
.
Fatalf
(
"jsapi prepay calls = %d, want 0"
,
jsapiCalls
)
}
if
h5Calls
!=
1
{
t
.
Fatalf
(
"h5 prepay calls = %d, want 1"
,
h5Calls
)
}
if
nativeCalls
!=
0
{
t
.
Fatalf
(
"native prepay calls = %d, want 0"
,
nativeCalls
)
}
if
resp
!=
nil
{
t
.
Fatalf
(
"expected nil response, got %+v"
,
resp
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"NO_AUTH"
)
{
t
.
Fatalf
(
"error = %v, want NO_AUTH"
,
err
)
}
}
backend/internal/payment/wire.go
View file @
ddf80f5e
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"encoding/hex"
"encoding/hex"
"fmt"
"fmt"
"log/slog"
"log/slog"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
...
@@ -19,11 +20,22 @@ type EncryptionKey []byte
...
@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
// to prevent startup with a misconfigured encryption key.
func
ProvideEncryptionKey
(
cfg
*
config
.
Config
)
(
EncryptionKey
,
error
)
{
func
ProvideEncryptionKey
(
cfg
*
config
.
Config
)
(
EncryptionKey
,
error
)
{
if
cfg
.
Totp
.
EncryptionKey
==
""
{
if
cfg
==
nil
{
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config and resume signing will be unavailable"
)
return
nil
,
nil
}
keyHex
:=
strings
.
TrimSpace
(
cfg
.
Totp
.
EncryptionKey
)
if
keyHex
==
""
{
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config will be unavailable"
)
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config will be unavailable"
)
return
nil
,
nil
return
nil
,
nil
}
}
key
,
err
:=
hex
.
DecodeString
(
cfg
.
Totp
.
EncryptionKey
)
// Reject auto-generated TOTP keys for payment signing.
// They change across restarts/instances and can silently break resume-token flows.
if
!
cfg
.
Totp
.
EncryptionKeyConfigured
{
slog
.
Warn
(
"payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens"
)
return
nil
,
nil
}
key
,
err
:=
hex
.
DecodeString
(
keyHex
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid payment encryption key (hex decode): %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"invalid payment encryption key (hex decode): %w"
,
err
)
}
}
...
...
backend/internal/payment/wire_test.go
0 → 100644
View file @
ddf80f5e
package
payment
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func
TestProvideEncryptionKeySkipsAutoGeneratedTotpKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
strings
.
Repeat
(
"a"
,
64
),
EncryptionKeyConfigured
:
false
,
},
}
key
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
!=
nil
{
t
.
Fatalf
(
"ProvideEncryptionKey returned error: %v"
,
err
)
}
if
len
(
key
)
!=
0
{
t
.
Fatalf
(
"encryption key len = %d, want 0"
,
len
(
key
))
}
}
func
TestProvideEncryptionKeyUsesConfiguredTotpKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
,
EncryptionKeyConfigured
:
true
,
},
}
key
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
!=
nil
{
t
.
Fatalf
(
"ProvideEncryptionKey returned error: %v"
,
err
)
}
if
len
(
key
)
!=
32
{
t
.
Fatalf
(
"encryption key len = %d, want 32"
,
len
(
key
))
}
}
func
TestProvideEncryptionKeyRejectsConfiguredInvalidLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
"abcd"
,
EncryptionKeyConfigured
:
true
,
},
}
_
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
==
nil
{
t
.
Fatal
(
"expected error for invalid key length"
)
}
}
backend/internal/repository/auth_identity_legacy_migration_integration_test.go
View file @
ddf80f5e
...
@@ -4,6 +4,7 @@ package repository
...
@@ -4,6 +4,7 @@ package repository
import
(
import
(
"context"
"context"
"database/sql"
"os"
"os"
"path/filepath"
"path/filepath"
"strconv"
"strconv"
...
@@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
...
@@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
CREATE TABLE IF NOT EXISTS user_external_identities (
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
var
linuxDoUserID
int64
var
linuxDoUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
...
@@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
...
@@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
CREATE TABLE IF NOT EXISTS user_external_identities (
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
var
linuxDoMalformedUserID
int64
var
linuxDoMalformedUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
...
@@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
...
@@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
CREATE TABLE IF NOT EXISTS user_external_identities (
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
userIDs
:=
make
([]
int64
,
0
,
8
)
userIDs
:=
make
([]
int64
,
0
,
8
)
for
_
,
email
:=
range
[]
string
{
for
_
,
email
:=
range
[]
string
{
...
@@ -643,6 +572,388 @@ FROM auth_identity_migration_reports
...
@@ -643,6 +572,388 @@ FROM auth_identity_migration_reports
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
SELECT COUNT(*)
FROM auth_identity_migration_reports
FROM auth_identity_migration_reports
`
)
.
Scan
(
&
afterCount
))
`
)
.
Scan
(
&
afterCount
))
require
.
Equal
(
t
,
beforeCount
,
afterCount
)
require
.
Equal
(
t
,
beforeCount
,
afterCount
)
}
}
func
TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migrationPath
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
new
(
int64
)))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migrationSQL
))
require
.
NoError
(
t
,
err
)
var
linuxDoIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`
)
.
Scan
(
&
linuxDoIdentityCount
))
require
.
Zero
(
t
,
linuxDoIdentityCount
)
var
wechatIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`
)
.
Scan
(
&
wechatIdentityCount
))
require
.
Zero
(
t
,
wechatIdentityCount
)
var
wechatChannelCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`
)
.
Scan
(
&
wechatChannelCount
))
require
.
Zero
(
t
,
wechatChannelCount
)
}
func
TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migration115Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migration115SQL
,
err
:=
os
.
ReadFile
(
migration115Path
)
require
.
NoError
(
t
,
err
)
migration116Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"116_auth_identity_legacy_external_safety_reports.sql"
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
var
linuxDoFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
&
linuxDoFirstLegacyID
))
var
linuxDoSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
&
linuxDoSecondLegacyID
))
var
wechatFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
&
wechatFirstLegacyID
))
var
wechatSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
&
wechatSecondLegacyID
))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration115SQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration116SQL
))
require
.
NoError
(
t
,
err
)
var
identityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`
)
.
Scan
(
&
identityCount
))
require
.
Zero
(
t
,
identityCount
)
var
conflictReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
conflictReportCount
))
require
.
Equal
(
t
,
4
,
conflictReportCount
)
var
winnerAttributedReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
winnerAttributedReportCount
))
require
.
Zero
(
t
,
winnerAttributedReportCount
)
}
func
TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migration108aPath
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"108a_widen_auth_identity_migration_report_type.sql"
)
migration108aSQL
,
err
:=
os
.
ReadFile
(
migration108aPath
)
require
.
NoError
(
t
,
err
)
migration109Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"109_auth_identity_compat_backfill.sql"
)
migration109SQL
,
err
:=
os
.
ReadFile
(
migration109Path
)
require
.
NoError
(
t
,
err
)
migration116Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"116_auth_identity_legacy_external_safety_reports.sql"
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(40);
`
)
require
.
NoError
(
t
,
err
)
var
oidcSyntheticUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
oidcSyntheticUserID
))
var
linuxdoLegacyUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxdoLegacyUserID
))
var
invalidMetadataLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
RETURNING id
`
,
linuxdoLegacyUserID
)
.
Scan
(
&
invalidMetadataLegacyID
))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration108aSQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration109SQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration116SQL
))
require
.
NoError
(
t
,
err
)
var
reportTypeWidth
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`
)
.
Scan
(
&
reportTypeWidth
))
require
.
Equal
(
t
,
80
,
reportTypeWidth
)
var
oidcSyntheticRecoveryReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`
,
strconv
.
FormatInt
(
oidcSyntheticUserID
,
10
))
.
Scan
(
&
oidcSyntheticRecoveryReportCount
))
require
.
Equal
(
t
,
1
,
oidcSyntheticRecoveryReportCount
)
var
invalidMetadataReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
invalidMetadataLegacyID
,
10
))
.
Scan
(
&
invalidMetadataReportCount
))
require
.
Equal
(
t
,
1
,
invalidMetadataReportCount
)
}
func
prepareLegacyExternalIdentitiesTable
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
ctx
context
.
Context
)
{
t
.
Helper
()
_
,
err
:=
tx
.
ExecContext
(
ctx
,
`
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
)
require
.
NoError
(
t
,
err
)
}
func
truncateAuthIdentityLegacyFixtureTables
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
ctx
context
.
Context
)
{
t
.
Helper
()
_
,
err
:=
tx
.
ExecContext
(
ctx
,
`
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
pending_auth_sessions,
auth_identities,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
}
backend/internal/repository/migrations_runner.go
View file @
ddf80f5e
...
@@ -51,34 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
...
@@ -51,34 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const
migrationsAdvisoryLockID
int64
=
694208311321144027
const
migrationsAdvisoryLockID
int64
=
694208311321144027
const
migrationsLockRetryInterval
=
500
*
time
.
Millisecond
const
migrationsLockRetryInterval
=
500
*
time
.
Millisecond
const
nonTransactionalMigrationSuffix
=
"_notx.sql"
const
nonTransactionalMigrationSuffix
=
"_notx.sql"
const
paymentOrdersOutTradeNoUniqueMigration
=
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const
paymentOrdersOutTradeNoUniqueIndex
=
"paymentorder_out_trade_no_unique"
type
migrationChecksumCompatibilityRule
struct
{
type
migrationChecksumCompatibilityRule
struct
{
fileChecksum
string
fileChecksum
string
acceptedDBChecksum
map
[
string
]
struct
{}
acceptedDBChecksum
map
[
string
]
struct
{}
acceptedChecksums
map
[
string
]
struct
{}
}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var
migrationChecksumCompatibilityRules
=
map
[
string
]
migrationChecksumCompatibilityRule
{
var
migrationChecksumCompatibilityRules
=
map
[
string
]
migrationChecksumCompatibilityRule
{
"054_drop_legacy_cache_columns.sql"
:
{
"054_drop_legacy_cache_columns.sql"
:
newMigrationChecksumCompatibilityRule
(
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d"
,
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"
),
fileChecksum
:
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d"
,
"061_add_usage_log_request_type.sql"
:
newMigrationChecksumCompatibilityRule
(
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c"
,
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0"
,
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"
),
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"109_auth_identity_compat_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
),
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"
:
{},
"110_pending_auth_and_provider_default_grants.sql"
:
newMigrationChecksumCompatibilityRule
(
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279"
,
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"
),
},
"112_add_payment_order_provider_key_snapshot.sql"
:
newMigrationChecksumCompatibilityRule
(
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99"
,
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"
),
},
"115_auth_identity_legacy_external_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
),
"061_add_usage_log_request_type.sql"
:
{
"116_auth_identity_legacy_external_safety_reports.sql"
:
newMigrationChecksumCompatibilityRule
(
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
),
fileChecksum
:
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c"
,
"118_wechat_dual_mode_and_auth_source_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0"
,
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"
,
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"
),
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"119_enforce_payment_orders_out_trade_no_unique.sql"
:
newMigrationChecksumCompatibilityRule
(
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
),
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0"
:
{},
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
newMigrationChecksumCompatibilityRule
(
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
),
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"
:
{},
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57"
,
"6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"
),
},
},
"109_auth_identity_compat_backfill.sql"
:
{
fileChecksum
:
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3"
:
{},
},
},
}
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
...
@@ -205,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
...
@@ -205,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
}
if
nonTx
{
if
nonTx
{
if
err
:=
prepareNonTransactionalMigration
(
ctx
,
db
,
name
);
err
!=
nil
{
return
fmt
.
Errorf
(
"prepare migration %s: %w"
,
name
,
err
)
}
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements
:=
splitSQLStatements
(
content
)
statements
:=
splitSQLStatements
(
content
)
...
@@ -254,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
...
@@ -254,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return
nil
return
nil
}
}
func
prepareNonTransactionalMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
name
string
)
error
{
switch
name
{
case
paymentOrdersOutTradeNoUniqueMigration
:
return
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
,
db
)
default
:
return
nil
}
}
func
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
error
{
duplicates
,
err
:=
findDuplicatePaymentOrderOutTradeNos
(
ctx
,
db
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"precheck duplicate out_trade_no: %w"
,
err
)
}
if
len
(
duplicates
)
>
0
{
return
fmt
.
Errorf
(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s"
,
paymentOrdersOutTradeNoUniqueMigration
,
strings
.
Join
(
duplicates
,
", "
),
)
}
invalid
,
err
:=
indexIsInvalid
(
ctx
,
db
,
paymentOrdersOutTradeNoUniqueIndex
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
if
!
invalid
{
return
nil
}
if
_
,
err
:=
db
.
ExecContext
(
ctx
,
fmt
.
Sprintf
(
"DROP INDEX CONCURRENTLY IF EXISTS %s"
,
paymentOrdersOutTradeNoUniqueIndex
));
err
!=
nil
{
return
fmt
.
Errorf
(
"drop invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
return
nil
}
func
findDuplicatePaymentOrderOutTradeNos
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
([]
string
,
error
)
{
rows
,
err
:=
db
.
QueryContext
(
ctx
,
`
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
duplicates
:=
make
([]
string
,
0
,
5
)
for
rows
.
Next
()
{
var
outTradeNo
string
var
duplicateCount
int
if
err
:=
rows
.
Scan
(
&
outTradeNo
,
&
duplicateCount
);
err
!=
nil
{
return
nil
,
err
}
duplicates
=
append
(
duplicates
,
fmt
.
Sprintf
(
"%s (count=%d)"
,
outTradeNo
,
duplicateCount
))
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
duplicates
,
nil
}
func
indexIsInvalid
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
indexName
string
)
(
bool
,
error
)
{
var
invalid
bool
err
:=
db
.
QueryRowContext
(
ctx
,
`
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`
,
indexName
)
.
Scan
(
&
invalid
)
return
invalid
,
err
}
func
ensureAtlasBaselineAligned
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
fsys
fs
.
FS
)
error
{
func
ensureAtlasBaselineAligned
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
fsys
fs
.
FS
)
error
{
hasLegacy
,
err
:=
tableExists
(
ctx
,
db
,
"schema_migrations"
)
hasLegacy
,
err
:=
tableExists
(
ctx
,
db
,
"schema_migrations"
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -328,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
...
@@ -328,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return
version
,
version
,
hash
,
nil
return
version
,
version
,
hash
,
nil
}
}
func
checksumSet
(
values
...
string
)
map
[
string
]
struct
{}
{
out
:=
make
(
map
[
string
]
struct
{},
len
(
values
))
for
_
,
value
:=
range
values
{
out
[
value
]
=
struct
{}{}
}
return
out
}
func
newMigrationChecksumCompatibilityRule
(
fileChecksum
string
,
acceptedDBChecksums
...
string
)
migrationChecksumCompatibilityRule
{
return
migrationChecksumCompatibilityRule
{
fileChecksum
:
fileChecksum
,
acceptedDBChecksum
:
checksumSet
(
acceptedDBChecksums
...
),
acceptedChecksums
:
checksumSet
(
append
([]
string
{
fileChecksum
},
acceptedDBChecksums
...
)
...
),
}
}
func
isMigrationChecksumCompatible
(
name
,
dbChecksum
,
fileChecksum
string
)
bool
{
func
isMigrationChecksumCompatible
(
name
,
dbChecksum
,
fileChecksum
string
)
bool
{
rule
,
ok
:=
migrationChecksumCompatibilityRules
[
name
]
rule
,
ok
:=
migrationChecksumCompatibilityRules
[
name
]
if
!
ok
{
if
!
ok
{
return
false
return
false
}
}
if
rule
.
fileChecksum
!=
fileChecksum
{
_
,
dbOK
:=
rule
.
acceptedChecksums
[
dbChecksum
]
if
!
dbOK
{
return
false
return
false
}
}
_
,
ok
=
rule
.
accepted
DB
Checksum
[
db
Checksum
]
_
,
fileOK
:
=
rule
.
acceptedChecksum
s
[
file
Checksum
]
return
ok
return
fileOK
}
}
func
validateMigrationExecutionMode
(
name
,
content
string
)
(
bool
,
error
)
{
func
validateMigrationExecutionMode
(
name
,
content
string
)
(
bool
,
error
)
{
...
...
backend/internal/repository/migrations_runner_checksum_test.go
View file @
ddf80f5e
...
@@ -55,9 +55,110 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
...
@@ -55,9 +55,110 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
t
.
Run
(
"109历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"109历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"109_auth_identity_compat_backfill.sql"
,
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
)
)
require
.
True
(
t
,
ok
)
require
.
True
(
t
,
ok
)
})
})
t
.
Run
(
"109当前checksum可兼容历史checksum"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"109回滚到历史文件后仍兼容已应用的新checksum"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"110历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"110_pending_auth_and_provider_default_grants.sql"
,
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"
,
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"112历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"112_add_payment_order_provider_key_snapshot.sql"
,
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"
,
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"115历史checksum可兼容修复后的legacy external backfill"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"115_auth_identity_legacy_external_backfill.sql"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
,
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"116历史checksum可兼容修复后的legacy external safety reports"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"116_auth_identity_legacy_external_safety_reports.sql"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
,
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"119历史checksum可兼容占位文件"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
,
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"118多个历史checksum都可兼容当前版本"
,
func
(
t
*
testing
.
T
)
{
for
_
,
dbChecksum
:=
range
[]
string
{
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"
,
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"
,
}
{
ok
:=
isMigrationChecksumCompatible
(
"118_wechat_dual_mode_and_auth_source_defaults.sql"
,
dbChecksum
,
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0"
,
)
require
.
True
(
t
,
ok
)
}
})
t
.
Run
(
"120多个历史checksum都可兼容新的notx修复版本"
,
func
(
t
*
testing
.
T
)
{
for
_
,
dbChecksum
:=
range
[]
string
{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
,
}
{
ok
:=
isMigrationChecksumCompatible
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
dbChecksum
,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
)
require
.
True
(
t
,
ok
)
}
})
t
.
Run
(
"119未知checksum不兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
,
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
,
)
require
.
False
(
t
,
ok
)
})
}
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment