Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ddf80f5e
Unverified
Commit
ddf80f5e
authored
Apr 22, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 22, 2026
Browse files
Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation
fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents
4d0483f5
c048ca80
Changes
140
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/auth_oauth_logout_test.go
0 → 100644
View file @
ddf80f5e
package
handler
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"logout-pending-session-token"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example"
)
.
SetProviderSubject
(
"logout-subject-123"
)
.
SetBrowserSessionKey
(
"logout-browser-session-key"
)
.
SetResolvedEmail
(
"logout@example.com"
)
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/logout"
,
nil
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"logout-browser-session-key"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthBindAccessTokenCookieName
,
Value
:
"bind-access-token"
})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
linuxDoOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"linuxdo-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oidcOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"oidc-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
wechatOAuthStateCookieName
,
Value
:
encodeCookieValue
(
"wechat-state"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
wechatPaymentOAuthStateName
,
Value
:
encodeCookieValue
(
"wechat-payment-state"
)})
ginCtx
.
Request
=
req
handler
.
Logout
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
cookies
:=
recorder
.
Result
()
.
Cookies
()
for
_
,
name
:=
range
[]
string
{
oauthPendingSessionCookieName
,
oauthPendingBrowserCookieName
,
oauthBindAccessTokenCookieName
,
linuxDoOAuthStateCookieName
,
oidcOAuthStateCookieName
,
wechatOAuthStateCookieName
,
wechatPaymentOAuthStateName
,
}
{
cookie
:=
findCookie
(
cookies
,
name
)
require
.
NotNil
(
t
,
cookie
,
name
)
require
.
Equal
(
t
,
-
1
,
cookie
.
MaxAge
,
name
)
require
.
True
(
t
,
cookie
.
HttpOnly
,
name
)
}
storedSession
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
IDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
storedSession
.
ConsumedAt
)
}
backend/internal/handler/auth_oauth_pending_flow.go
View file @
ddf80f5e
...
...
@@ -265,16 +265,20 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return
strings
.
EqualFold
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
payload
,
"error"
)),
"invitation_required"
)
}
func
pendingOAuthCompletion
IncludesTokenPayload
(
payload
map
[
string
]
any
)
bool
{
if
len
(
payload
)
==
0
{
func
pendingOAuthCompletion
CanIssueTokenPair
(
session
*
dbent
.
PendingAuthSession
,
payload
map
[
string
]
any
)
bool
{
if
session
==
nil
{
return
false
}
for
_
,
key
:=
range
[]
string
{
"access_token"
,
"refresh_token"
}
{
if
value
:=
pendingSessionStringValue
(
payload
,
key
);
value
!=
""
{
return
true
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
Intent
),
oauthIntentLogin
)
{
return
false
}
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
return
false
}
if
pendingSessionWantsInvitation
(
payload
)
{
return
false
}
return
false
return
strings
.
TrimSpace
(
pendingSessionStringValue
(
payload
,
"step"
))
==
""
}
func
ensurePendingOAuthCompleteRegistrationSession
(
session
*
dbent
.
PendingAuthSession
)
error
{
...
...
@@ -294,6 +298,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return
nil
}
func
buildLegacyCompleteRegistrationPendingResponse
(
session
*
dbent
.
PendingAuthSession
,
forceEmailOnSignup
bool
,
emailVerificationRequired
bool
,
)
map
[
string
]
any
{
completionResponse
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"adoption_required"
:
true
,
"create_account_allowed"
:
true
,
"force_email_on_signup"
:
forceEmailOnSignup
,
}))
if
email
:=
strings
.
TrimSpace
(
session
.
ResolvedEmail
);
email
!=
""
{
if
_
,
exists
:=
completionResponse
[
"email"
];
!
exists
{
completionResponse
[
"email"
]
=
email
}
if
_
,
exists
:=
completionResponse
[
"resolved_email"
];
!
exists
{
completionResponse
[
"resolved_email"
]
=
email
}
}
if
_
,
exists
:=
completionResponse
[
"choice_reason"
];
!
exists
{
switch
{
case
forceEmailOnSignup
:
completionResponse
[
"choice_reason"
]
=
"force_email_on_signup"
case
emailVerificationRequired
:
completionResponse
[
"choice_reason"
]
=
"email_verification_required"
default
:
completionResponse
[
"choice_reason"
]
=
"third_party_signup"
}
}
return
completionResponse
}
func
(
h
*
AuthHandler
)
legacyCompleteRegistrationSessionStatus
(
c
*
gin
.
Context
,
session
*
dbent
.
PendingAuthSession
,
)
(
*
dbent
.
PendingAuthSession
,
bool
,
error
)
{
if
session
==
nil
{
return
nil
,
false
,
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
payload
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
nil
))
if
step
:=
pendingSessionStringValue
(
payload
,
"step"
);
step
!=
""
{
return
session
,
true
,
nil
}
emailVerificationRequired
:=
h
!=
nil
&&
h
.
authService
!=
nil
&&
h
.
authService
.
IsEmailVerifyEnabled
(
c
.
Request
.
Context
())
forceEmailOnSignup
:=
h
.
isForceEmailOnThirdPartySignup
(
c
.
Request
.
Context
())
if
!
emailVerificationRequired
&&
!
forceEmailOnSignup
{
return
session
,
false
,
nil
}
client
:=
h
.
entClient
()
if
client
==
nil
{
return
nil
,
false
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
updatedSession
,
err
:=
updatePendingOAuthSessionProgress
(
c
.
Request
.
Context
(),
client
,
session
,
strings
.
TrimSpace
(
session
.
Intent
),
strings
.
TrimSpace
(
session
.
ResolvedEmail
),
nil
,
buildLegacyCompleteRegistrationPendingResponse
(
session
,
forceEmailOnSignup
,
emailVerificationRequired
),
)
if
err
!=
nil
{
return
nil
,
false
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_SESSION_UPDATE_FAILED"
,
"failed to update pending oauth session"
)
.
WithCause
(
err
)
}
return
updatedSession
,
true
,
nil
}
func
(
r
oauthAdoptionDecisionRequest
)
hasDecision
()
bool
{
return
r
.
AdoptDisplayName
!=
nil
||
r
.
AdoptAvatar
!=
nil
}
...
...
@@ -376,15 +452,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
userEntity
,
err
:=
client
.
User
.
Get
(
ctx
,
record
.
UserID
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_USER_LOOKUP_FAILED"
,
"failed to load auth identity user"
)
.
WithCause
(
err
)
}
return
userEntity
,
nil
return
findActiveUserByID
(
ctx
,
client
,
record
.
UserID
)
}
func
(
h
*
AuthHandler
)
BindLinuxDoOAuthLogin
(
c
*
gin
.
Context
)
{
h
.
bindPendingOAuthLogin
(
c
,
"linuxdo"
)
}
...
...
@@ -439,7 +507,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
req
.
Email
))
if
existingUser
,
err
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
);
err
==
nil
&&
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -624,6 +692,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return
matches
[
0
],
nil
}
func
ensurePendingOAuthRegistrationIdentityAvailable
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
)
error
{
if
client
==
nil
||
session
==
nil
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
session
.
ProviderType
)),
authidentity
.
ProviderKeyEQ
(
strings
.
TrimSpace
(
session
.
ProviderKey
)),
authidentity
.
ProviderSubjectEQ
(
strings
.
TrimSpace
(
session
.
ProviderSubject
)),
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
}
return
err
}
if
identity
==
nil
||
identity
.
UserID
<=
0
{
return
nil
}
activeOwner
,
err
:=
findActiveUserByID
(
ctx
,
client
,
identity
.
UserID
)
if
err
!=
nil
{
return
err
}
if
activeOwner
!=
nil
{
return
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
return
nil
}
func
oauthIdentityIssuer
(
session
*
dbent
.
PendingAuthSession
)
*
string
{
if
session
==
nil
{
return
nil
...
...
@@ -910,6 +1010,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_USER_LOOKUP_FAILED"
,
"failed to load auth identity user"
)
.
WithCause
(
err
)
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
userEntity
.
Status
),
service
.
StatusActive
)
{
return
nil
,
service
.
ErrUserNotActive
}
return
userEntity
,
nil
}
...
...
@@ -1123,6 +1226,38 @@ func consumePendingOAuthBrowserSessionTx(
return
nil
}
func
applyPendingOAuthAdoptionAndConsumeSession
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
session
*
dbent
.
PendingAuthSession
,
decision
*
dbent
.
IdentityAdoptionDecision
,
userID
int64
,
)
error
{
if
client
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
if
session
==
nil
||
userID
<=
0
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
tx
,
err
:=
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txCtx
:=
dbent
.
NewTxContext
(
ctx
,
tx
)
if
err
:=
applyPendingOAuthAdoption
(
txCtx
,
client
,
authService
,
userService
,
session
,
decision
,
&
userID
);
err
!=
nil
{
return
err
}
if
err
:=
consumePendingOAuthBrowserSessionTx
(
txCtx
,
tx
,
session
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
applyPendingOAuthAdoption
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
...
...
@@ -1212,13 +1347,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
if
session
==
nil
||
len
(
payload
)
==
0
{
return
false
,
nil
}
if
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
Intent
),
oauthIntentLogin
)
{
return
false
,
nil
}
if
!
pendingOAuthCompletionIncludesTokenPayload
(
payload
)
{
return
false
,
nil
}
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
if
!
pendingOAuthCompletionCanIssueTokenPair
(
session
,
payload
)
{
return
false
,
nil
}
if
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"suggested_display_name"
)
==
""
&&
...
...
@@ -1262,6 +1391,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return
svc
,
session
,
clearCookies
,
nil
}
func
(
h
*
AuthHandler
)
consumePendingOAuthSessionOnLogout
(
c
*
gin
.
Context
)
{
if
c
==
nil
||
c
.
Request
==
nil
{
return
}
sessionToken
,
err
:=
readOAuthPendingSessionCookie
(
c
)
if
err
!=
nil
||
strings
.
TrimSpace
(
sessionToken
)
==
""
{
return
}
browserSessionKey
,
err
:=
readOAuthPendingBrowserCookie
(
c
)
if
err
!=
nil
||
strings
.
TrimSpace
(
browserSessionKey
)
==
""
{
return
}
svc
,
err
:=
h
.
pendingIdentityService
()
if
err
!=
nil
{
return
}
_
,
_
=
svc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
)
}
func
clearOAuthLogoutCookies
(
c
*
gin
.
Context
)
{
secureCookie
:=
isRequestHTTPS
(
c
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
clearOAuthBindAccessTokenCookie
(
c
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthStateCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthVerifierCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthRedirectCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthIntentCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthBindUserCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthStateCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthVerifierCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthRedirectCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthNonceCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthIntentCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthBindUserCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthStateCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthRedirectCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthIntentCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthModeCookieName
,
secureCookie
)
wechatClearCookie
(
c
,
wechatOAuthBindUserCookieName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthStateName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthRedirect
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthContextName
,
secureCookie
)
wechatPaymentClearCookie
(
c
,
wechatPaymentOAuthScope
,
secureCookie
)
}
func
buildPendingOAuthSessionStatusPayload
(
session
*
dbent
.
PendingAuthSession
)
gin
.
H
{
completionResponse
:=
normalizePendingOAuthCompletionResponse
(
mergePendingCompletionResponse
(
session
,
nil
))
payload
:=
gin
.
H
{
...
...
@@ -1280,6 +1462,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func
normalizePendingOAuthCompletionResponse
(
payload
map
[
string
]
any
)
map
[
string
]
any
{
normalized
:=
clonePendingMap
(
payload
)
for
_
,
key
:=
range
[]
string
{
"access_token"
,
"refresh_token"
,
"expires_in"
,
"token_type"
}
{
delete
(
normalized
,
key
)
}
step
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
pendingSessionStringValue
(
normalized
,
"step"
)))
switch
step
{
case
"choice"
,
"choose_account_action"
,
"choose_account"
,
"choose"
,
"email_required"
,
"bind_login_required"
:
...
...
@@ -1315,16 +1500,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c
*
gin
.
Context
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
,
targetUser
*
dbent
.
User
,
email
string
,
)
(
*
dbent
.
PendingAuthSession
,
error
)
{
completionResponse
:=
pendingOAuthChoiceCompletionResponse
(
session
,
email
)
var
targetUserID
*
int64
if
targetUser
!=
nil
&&
targetUser
.
ID
>
0
{
targetUserID
=
&
targetUser
.
ID
}
session
,
err
:=
updatePendingOAuthSessionProgress
(
c
.
Request
.
Context
(),
client
,
session
,
strings
.
TrimSpace
(
session
.
Intent
),
email
,
nil
,
targetUserID
,
completionResponse
,
)
if
err
!=
nil
{
...
...
@@ -1438,6 +1628,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response
.
ErrorFrom
(
c
,
err
)
return
}
if
err
:=
ensurePendingOAuthCompleteRegistrationSession
(
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
strings
.
TrimSpace
(
provider
)
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
ProviderType
),
provider
)
{
response
.
BadRequest
(
c
,
"Pending oauth session provider mismatch"
)
return
...
...
@@ -1464,7 +1658,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
}
if
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -1487,7 +1681,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrEmailExists
)
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
existingUser
,
lookupErr
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
)
if
lookupErr
!=
nil
{
response
.
ErrorFrom
(
c
,
lookupErr
)
return
}
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -1649,33 +1848,35 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse
(
payload
,
session
.
UpstreamIdentityClaims
)
skipAdoptionPrompt
,
err
:=
h
.
shouldSkipPendingOAuthAdoptionPrompt
(
c
.
Request
.
Context
(),
session
,
payload
)
if
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
return
}
if
skipAdoptionPrompt
{
delete
(
payload
,
"adoption_required"
)
}
if
pendingOAuthCompletionIncludesTokenPayload
(
payload
)
{
if
session
.
TargetUserID
==
nil
||
*
session
.
TargetUserID
<=
0
{
canIssueTokenPair
:=
pendingOAuthCompletionCanIssueTokenPair
(
session
,
payload
)
var
loginUser
*
service
.
User
if
canIssueTokenPair
{
loginUser
,
err
=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
*
session
.
TargetUserID
)
if
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_COMPLETION_INVALID"
,
"pending auth completion payload is invalid"
)
)
response
.
ErrorFrom
(
c
,
err
)
return
}
user
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
(),
*
session
.
TargetUserID
)
if
err
!=
nil
{
if
err
:=
ensureLoginUserActive
(
loginUser
);
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
return
}
if
err
:=
h
.
ensureBackendModeAllowsUser
(
c
.
Request
.
Context
(),
u
ser
);
err
!=
nil
{
if
err
:=
h
.
ensureBackendModeAllowsUser
(
c
.
Request
.
Context
(),
loginU
ser
);
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
}
skipAdoptionPrompt
,
err
:=
h
.
shouldSkipPendingOAuthAdoptionPrompt
(
c
.
Request
.
Context
(),
session
,
payload
)
if
err
!=
nil
{
clearCookies
()
response
.
ErrorFrom
(
c
,
err
)
return
}
if
skipAdoptionPrompt
{
delete
(
payload
,
"adoption_required"
)
}
if
pendingSessionWantsInvitation
(
payload
)
{
...
...
@@ -1724,6 +1925,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
}
if
canIssueTokenPair
{
tokenPair
,
err
:=
h
.
authService
.
GenerateTokenPair
(
c
.
Request
.
Context
(),
loginUser
,
""
)
if
err
!=
nil
{
clearCookies
()
response
.
InternalError
(
c
,
"Failed to generate token pair"
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
loginUser
.
ID
)
payload
[
"access_token"
]
=
tokenPair
.
AccessToken
payload
[
"refresh_token"
]
=
tokenPair
.
RefreshToken
payload
[
"expires_in"
]
=
tokenPair
.
ExpiresIn
payload
[
"token_type"
]
=
"Bearer"
}
clearCookies
()
response
.
Success
(
c
,
payload
)
}
backend/internal/handler/auth_oauth_pending_flow_test.go
View file @
ddf80f5e
...
...
@@ -746,8 +746,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"access_token"
:
"access-token"
,
"refresh_token"
:
"refresh-token"
,
"access_token"
:
"
legacy-
access-token"
,
"refresh_token"
:
"
legacy-
refresh-token"
,
"expires_in"
:
float64
(
3600
),
"token_type"
:
"Bearer"
,
"redirect"
:
"/dashboard"
,
...
...
@@ -769,13 +769,23 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
require
.
Equal
(
t
,
"access-token"
,
payload
[
"access_token"
])
require
.
Equal
(
t
,
"refresh-token"
,
payload
[
"refresh_token"
])
require
.
NotEmpty
(
t
,
payload
[
"access_token"
])
require
.
NotEmpty
(
t
,
payload
[
"refresh_token"
])
require
.
NotEqual
(
t
,
"legacy-access-token"
,
payload
[
"access_token"
])
require
.
NotEqual
(
t
,
"legacy-refresh-token"
,
payload
[
"refresh_token"
])
require
.
Equal
(
t
,
"/dashboard"
,
payload
[
"redirect"
])
require
.
Equal
(
t
,
"Existing Login Example"
,
payload
[
"suggested_display_name"
])
require
.
Equal
(
t
,
"https://cdn.example/existing-login.png"
,
payload
[
"suggested_avatar_url"
])
require
.
NotContains
(
t
,
payload
,
"adoption_required"
)
accessToken
,
ok
:=
payload
[
"access_token"
]
.
(
string
)
require
.
True
(
t
,
ok
)
claims
,
err
:=
handler
.
authService
.
ValidateToken
(
accessToken
)
require
.
NoError
(
t
,
err
)
reloadedUser
,
err
:=
handler
.
userService
.
GetByID
(
ctx
,
userEntity
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
reloadedUser
.
TokenVersion
,
claims
.
TokenVersion
)
decisionCount
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Count
(
ctx
)
...
...
@@ -785,6 +795,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
storedSession
.
ConsumedAt
)
completion
,
ok
:=
storedSession
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
NotContains
(
t
,
completion
,
"access_token"
)
require
.
NotContains
(
t
,
completion
,
"refresh_token"
)
require
.
NotContains
(
t
,
completion
,
"expires_in"
)
require
.
NotContains
(
t
,
completion
,
"token_type"
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
}
func
TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload
(
t
*
testing
.
T
)
{
...
...
@@ -841,6 +859,72 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestExchangePendingOAuthCompletionRejectsDisabledTargetUser
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
userEntity
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"disabled-linked@example.com"
)
.
SetUsername
(
"disabled-linked-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"disabled-linked-session-token"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"linuxdo"
)
.
SetProviderKey
(
"linuxdo"
)
.
SetProviderSubject
(
"disabled-linked-subject"
)
.
SetTargetUserID
(
userEntity
.
ID
)
.
SetResolvedEmail
(
userEntity
.
Email
)
.
SetBrowserSessionKey
(
"disabled-linked-browser-session-key"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"suggested_display_name"
:
"Disabled Linked User"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"redirect"
:
"/dashboard"
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/pending/exchange"
,
nil
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"disabled-linked-browser-session-key"
)})
ginCtx
.
Request
=
req
handler
.
ExchangePendingOAuthCompletion
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusForbidden
,
recorder
.
Code
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload
(
t
*
testing
.
T
)
{
payload
:=
normalizePendingOAuthCompletionResponse
(
map
[
string
]
any
{
"access_token"
:
"legacy-access-token"
,
"refresh_token"
:
"legacy-refresh-token"
,
"expires_in"
:
float64
(
3600
),
"token_type"
:
"Bearer"
,
"redirect"
:
"/dashboard"
,
})
require
.
NotContains
(
t
,
payload
,
"access_token"
)
require
.
NotContains
(
t
,
payload
,
"refresh_token"
)
require
.
NotContains
(
t
,
payload
,
"expires_in"
)
require
.
NotContains
(
t
,
payload
,
"token_type"
)
require
.
Equal
(
t
,
"/dashboard"
,
payload
[
"redirect"
])
}
func
TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
true
)
ctx
:=
context
.
Background
()
...
...
@@ -969,7 +1053,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1023,7 +1107,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
...
...
@@ -1042,7 +1127,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
" Owner@Example.com "
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1088,7 +1173,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
...
...
@@ -1096,7 +1182,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1144,7 +1230,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
...
...
@@ -1202,6 +1289,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestLogoutClearsPendingOAuthAndBindCookies
(
t
*
testing
.
T
)
{
handler
,
_
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
recorder
:=
httptest
.
NewRecorder
()
ginCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/logout"
,
bytes
.
NewBufferString
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
"pending-session-token"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"pending-browser-key"
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthBindAccessTokenCookieName
,
Value
:
"bind-token"
})
ginCtx
.
Request
=
req
handler
.
Logout
(
ginCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
)
.
MaxAge
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingBrowserCookieName
)
.
MaxAge
)
require
.
Equal
(
t
,
-
1
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthBindAccessTokenCookieName
)
.
MaxAge
)
}
func
TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
true
,
"fresh@example.com"
,
"246810"
)
ctx
:=
context
.
Background
()
...
...
@@ -1934,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
payload
:=
decodeJSONResponseData
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
payload
[
"access_token"
])
require
.
NotEmpty
(
t
,
payload
[
"refresh_token"
])
accessToken
,
ok
:=
payload
[
"access_token"
]
.
(
string
)
require
.
True
(
t
,
ok
)
claims
,
err
:=
handler
.
authService
.
ValidateToken
(
accessToken
)
require
.
NoError
(
t
,
err
)
reloadedUser
,
err
:=
handler
.
userService
.
GetByID
(
ctx
,
existingUser
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
reloadedUser
.
TokenVersion
,
claims
.
TokenVersion
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
...
...
backend/internal/handler/auth_oauth_test_helpers_test.go
View file @
ddf80f5e
...
...
@@ -2,6 +2,7 @@ package handler
import
(
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
...
...
@@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
require
.
NoError
(
t
,
err
)
return
decoded
}
func
assertOAuthRedirectError
(
t
*
testing
.
T
,
location
string
,
errorCode
string
,
errorMessage
string
)
{
t
.
Helper
()
require
.
NotEmpty
(
t
,
location
)
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
rawValues
:=
parsed
.
RawQuery
if
rawValues
==
""
{
rawValues
=
parsed
.
Fragment
}
values
,
err
:=
url
.
ParseQuery
(
rawValues
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
errorCode
,
values
.
Get
(
"error"
))
require
.
Equal
(
t
,
errorMessage
,
values
.
Get
(
"error_message"
))
}
backend/internal/handler/auth_oidc_oauth.go
View file @
ddf80f5e
...
...
@@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
codeChallenge
:=
""
verifier
,
genErr
:=
oauth
.
GenerateCodeVerifier
()
if
genErr
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
genErr
))
return
if
cfg
.
UsePKCE
{
verifier
,
genErr
:=
oauth
.
GenerateCodeVerifier
()
if
genErr
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
genErr
))
return
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
oidcSetCookie
(
c
,
oidcOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
oidcSetCookie
(
c
,
oidcOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
nonce
:=
""
nonce
,
err
=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_NONCE_GEN_FAILED"
,
"failed to generate oauth nonce"
)
.
WithCause
(
err
))
return
if
cfg
.
ValidateIDToken
{
nonce
,
err
=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_NONCE_GEN_FAILED"
,
"failed to generate oauth nonce"
)
.
WithCause
(
err
))
return
}
oidcSetCookie
(
c
,
oidcOAuthNonceCookie
,
encodeCookieValue
(
nonce
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
oidcSetCookie
(
c
,
oidcOAuthNonceCookie
,
encodeCookieValue
(
nonce
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
...
...
@@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
intent
=
normalizeOAuthIntent
(
intent
)
codeVerifier
:=
""
codeVerifier
,
_
=
readCookieDecoded
(
c
,
oidcOAuthVerifierCookie
)
if
codeVerifier
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
if
cfg
.
UsePKCE
{
codeVerifier
,
_
=
readCookieDecoded
(
c
,
oidcOAuthVerifierCookie
)
if
codeVerifier
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
}
}
expectedNonce
:=
""
expectedNonce
,
_
=
readCookieDecoded
(
c
,
oidcOAuthNonceCookie
)
if
expectedNonce
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_nonce"
,
"missing oauth nonce"
,
""
)
return
if
cfg
.
ValidateIDToken
{
expectedNonce
,
_
=
readCookieDecoded
(
c
,
oidcOAuthNonceCookie
)
if
expectedNonce
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_nonce"
,
"missing oauth nonce"
,
""
)
return
}
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
...
...
@@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_id_token"
,
"missing id_token"
,
""
)
return
}
var
idClaims
*
oidcIDTokenClaims
if
cfg
.
ValidateIDToken
{
if
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_id_token"
,
"missing id_token"
,
""
)
return
}
idClaims
,
err
:=
oidcParseAndValidateIDToken
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
.
IDToken
,
expectedNonce
)
if
err
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] id_token validation failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_id_token"
,
"failed to validate id_token"
,
""
)
return
idClaims
,
err
=
oidcParseAndValidateIDToken
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
.
IDToken
,
expectedNonce
)
if
err
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] id_token validation failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_id_token"
,
"failed to validate id_token"
,
""
)
return
}
}
userInfoClaims
,
err
:=
oidcFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
...
...
@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
subject
:=
strings
.
TrimSpace
(
idClaims
.
Subject
)
subject
:=
""
if
idClaims
!=
nil
{
subject
=
strings
.
TrimSpace
(
idClaims
.
Subject
)
}
if
subject
==
""
{
subject
=
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
}
...
...
@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError
(
c
,
frontendCallback
,
"missing_subject"
,
"missing subject claim"
,
""
)
return
}
issuer
:=
strings
.
TrimSpace
(
idClaims
.
Issuer
)
issuer
:=
""
if
idClaims
!=
nil
{
issuer
=
strings
.
TrimSpace
(
idClaims
.
Issuer
)
}
if
issuer
==
""
{
issuer
=
strings
.
TrimSpace
(
cfg
.
IssuerURL
)
}
...
...
@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified
:=
userInfoClaims
.
EmailVerified
if
emailVerified
==
nil
{
if
emailVerified
==
nil
&&
idClaims
!=
nil
{
emailVerified
=
idClaims
.
EmailVerified
}
if
userInfoClaims
.
Subject
!=
""
&&
idClaims
.
Subject
!=
""
&&
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
!=
strings
.
TrimSpace
(
idClaims
.
Subject
)
{
if
idClaims
!=
nil
&&
userInfoClaims
.
Subject
!=
""
&&
idClaims
.
Subject
!=
""
&&
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
!=
strings
.
TrimSpace
(
idClaims
.
Subject
)
{
redirectOAuthError
(
c
,
frontendCallback
,
"subject_mismatch"
,
"userinfo subject does not match id_token"
,
""
)
return
}
identityKey
:=
oidcIdentityKey
(
issuer
,
subject
)
compatEmail
:=
strings
.
TrimSpace
(
firstNonEmpty
(
userInfoClaims
.
Email
,
idClaims
.
Email
))
compatEmail
:=
strings
.
TrimSpace
(
userInfoClaims
.
Email
)
if
compatEmail
==
""
&&
idClaims
!=
nil
{
compatEmail
=
strings
.
TrimSpace
(
idClaims
.
Email
)
}
email
:=
oidcSyntheticEmailFromIdentityKey
(
identityKey
)
username
:=
firstNonEmpty
(
userInfoClaims
.
Username
,
idClaims
.
PreferredUsername
,
idClaims
.
Name
,
func
()
string
{
if
idClaims
!=
nil
{
return
idClaims
.
PreferredUsername
}
return
""
}(),
func
()
string
{
if
idClaims
!=
nil
{
return
idClaims
.
Name
}
return
""
}(),
oidcFallbackUsername
(
subject
),
)
identityRef
:=
service
.
PendingAuthIdentityKey
{
...
...
@@ -344,14 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderSubject
:
subject
,
}
upstreamClaims
:=
map
[
string
]
any
{
"email"
:
email
,
"username"
:
username
,
"subject"
:
subject
,
"issuer"
:
issuer
,
"email_verified"
:
emailVerified
!=
nil
&&
*
emailVerified
,
"provider_fallback"
:
strings
.
TrimSpace
(
cfg
.
ProviderName
),
"suggested_display_name"
:
firstNonEmpty
(
userInfoClaims
.
DisplayName
,
idClaims
.
Name
,
username
),
"suggested_avatar_url"
:
userInfoClaims
.
AvatarURL
,
"email"
:
email
,
"username"
:
username
,
"subject"
:
subject
,
"issuer"
:
issuer
,
"email_verified"
:
emailVerified
!=
nil
&&
*
emailVerified
,
"provider_fallback"
:
strings
.
TrimSpace
(
cfg
.
ProviderName
),
"suggested_display_name"
:
firstNonEmpty
(
userInfoClaims
.
DisplayName
,
func
()
string
{
if
idClaims
!=
nil
{
return
idClaims
.
Name
}
return
""
}(),
username
),
"suggested_avatar_url"
:
userInfoClaims
.
AvatarURL
,
}
if
compatEmail
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
compatEmail
),
strings
.
TrimSpace
(
email
))
{
upstreamClaims
[
"compat_email"
]
=
compatEmail
...
...
@@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if
existingIdentityUser
!=
nil
{
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
if
err
:=
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Identity
:
identityRef
,
TargetUserID
:
&
u
ser
.
ID
,
TargetUserID
:
&
existingIdentityU
ser
.
ID
,
ResolvedEmail
:
existingIdentityUser
.
Email
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
UpstreamIdentityClaims
:
upstreamClaims
,
CompletionResponse
:
map
[
string
]
any
{
"access_token"
:
tokenPair
.
AccessToken
,
"refresh_token"
:
tokenPair
.
RefreshToken
,
"expires_in"
:
tokenPair
.
ExpiresIn
,
"token_type"
:
"Bearer"
,
"redirect"
:
redirectTo
,
"redirect"
:
redirectTo
,
},
});
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
...
...
@@ -537,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if
compatEmailUser
!=
nil
{
resolvedChoiceEmail
=
strings
.
TrimSpace
(
compatEmailUser
.
Email
)
}
var
targetUserID
*
int64
if
compatEmailUser
!=
nil
&&
compatEmailUser
.
ID
>
0
{
targetUserID
=
&
compatEmailUser
.
ID
}
return
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Identity
:
identity
,
TargetUserID
:
targetUserID
,
ResolvedEmail
:
resolvedChoiceEmail
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
...
...
@@ -596,6 +627,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
updatedSession
,
handled
,
err
:=
h
.
legacyCompleteRegistrationSessionStatus
(
c
,
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
else
if
handled
{
c
.
JSON
(
http
.
StatusOK
,
buildPendingOAuthSessionStatusPayload
(
updatedSession
))
return
}
else
{
session
=
updatedSession
}
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -608,12 +648,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
client
:=
h
.
entClient
()
if
client
==
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
))
return
}
if
err
:=
ensurePendingOAuthRegistrationIdentityAvailable
(
c
.
Request
.
Context
(),
client
,
session
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
decision
,
err
:=
h
.
upsert
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
decision
,
err
:=
h
.
ensure
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptAvatar
:
req
.
AdoptAvatar
,
})
...
...
@@ -621,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
err
:=
applyPendingOAuthAdoption
(
c
.
Request
.
Context
(),
h
.
entClient
(),
h
.
authService
,
h
.
userService
,
session
,
decision
,
&
user
.
ID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_ADOPTION_APPLY_FAILED"
,
"failed to apply oauth profile adoption"
)
.
WithCause
(
err
))
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
if
_
,
err
:=
pendingSvc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
);
err
!=
nil
{
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
response
.
ErrorFrom
(
c
,
err
)
if
err
:=
applyPendingOAuthAdoptionAndConsumeSession
(
c
.
Request
.
Context
(),
client
,
h
.
authService
,
h
.
userService
,
session
,
decision
,
user
.
ID
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
...
...
@@ -670,7 +713,9 @@ func oidcExchangeCode(
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
form
.
Set
(
"code_verifier"
,
codeVerifier
)
if
strings
.
TrimSpace
(
codeVerifier
)
!=
""
{
form
.
Set
(
"code_verifier"
,
codeVerifier
)
}
r
:=
client
.
R
()
.
SetContext
(
ctx
)
.
...
...
@@ -872,9 +917,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
}
q
.
Set
(
"state"
,
state
)
q
.
Set
(
"nonce"
,
nonce
)
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
if
strings
.
TrimSpace
(
nonce
)
!=
""
{
q
.
Set
(
"nonce"
,
nonce
)
}
if
strings
.
TrimSpace
(
codeChallenge
)
!=
""
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
u
.
RawQuery
=
q
.
Encode
()
return
u
.
String
(),
nil
...
...
backend/internal/handler/auth_oidc_oauth_test.go
View file @
ddf80f5e
...
...
@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require
.
Equal
(
t
,
int64
(
84
),
userID
)
}
func
TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled
(
t
*
testing
.
T
)
{
handler
:=
newOIDCOAuthTestHandler
(
t
,
false
,
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ClientID
:
"oidc-client"
,
ClientSecret
:
"oidc-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
AuthorizeURL
:
"https://issuer.example.com/oauth/authorize"
,
TokenURL
:
"https://issuer.example.com/oauth/token"
,
UserInfoURL
:
"https://issuer.example.com/oauth/userinfo"
,
Scopes
:
"openid profile email"
,
RedirectURL
:
"https://api.example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
false
,
ValidateIDToken
:
false
,
RequireEmailVerified
:
false
,
})
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/start?redirect=/dashboard"
,
nil
)
handler
.
OIDCOAuthStart
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
location
:=
recorder
.
Header
()
.
Get
(
"Location"
)
require
.
NotContains
(
t
,
location
,
"code_challenge="
)
require
.
NotContains
(
t
,
location
,
"nonce="
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oidcOAuthVerifierCookie
))
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oidcOAuthNonceCookie
))
}
func
TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation
(
t
*
testing
.
T
)
{
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
r
.
URL
.
Path
{
case
"/token"
:
require
.
NoError
(
t
,
r
.
ParseForm
())
require
.
Empty
(
t
,
r
.
PostForm
.
Get
(
"code_verifier"
))
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`
))
case
"/userinfo"
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
handler
,
client
:=
newOIDCOAuthHandlerAndClient
(
t
,
false
,
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ClientID
:
"oidc-client"
,
ClientSecret
:
"oidc-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
AuthorizeURL
:
upstream
.
URL
+
"/authorize"
,
TokenURL
:
upstream
.
URL
+
"/token"
,
UserInfoURL
:
upstream
.
URL
+
"/userinfo"
,
Scopes
:
"openid profile email"
,
RedirectURL
:
"https://api.example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
false
,
ValidateIDToken
:
false
,
RequireEmailVerified
:
false
,
})
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123"
,
nil
)
req
.
AddCookie
(
encodedCookie
(
oidcOAuthStateCookieName
,
"state-123"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthRedirectCookie
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthIntentCookieName
,
oauthIntentLogin
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-123"
))
c
.
Request
=
req
handler
.
OIDCOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Equal
(
t
,
"/auth/oidc/callback"
,
recorder
.
Header
()
.
Get
(
"Location"
))
require
.
NotNil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
}
func
TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-subject-login"
,
...
...
@@ -250,10 +333,63 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
completion
,
ok
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
require
.
NotEmpty
(
t
,
completion
[
"access_token"
])
_
,
hasAccessToken
:=
completion
[
"access_token"
]
require
.
False
(
t
,
hasAccessToken
)
_
,
hasRefreshToken
:=
completion
[
"refresh_token"
]
require
.
False
(
t
,
hasRefreshToken
)
require
.
Nil
(
t
,
completion
[
"error"
])
}
func
TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-disabled-subject"
,
PreferredUsername
:
"oidc_disabled"
,
DisplayName
:
"OIDC Disabled"
,
})
defer
cleanup
()
handler
,
client
:=
newOIDCOAuthHandlerAndClient
(
t
,
false
,
cfg
)
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
oidcSyntheticEmailFromIdentityKey
(
oidcIdentityKey
(
cfg
.
IssuerURL
,
"oidc-disabled-subject"
)))
.
SetUsername
(
"disabled-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
cfg
.
IssuerURL
)
.
SetProviderSubject
(
"oidc-disabled-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled"
,
nil
)
req
.
AddCookie
(
encodedCookie
(
oidcOAuthStateCookieName
,
"state-disabled"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthRedirectCookie
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthVerifierCookie
,
"verifier-disabled"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthNonceCookie
,
"nonce-oidc-disabled-subject"
))
req
.
AddCookie
(
encodedCookie
(
oidcOAuthIntentCookieName
,
oauthIntentLogin
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-disabled"
))
c
.
Request
=
req
handler
.
OIDCOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
assertOAuthRedirectError
(
t
,
recorder
.
Header
()
.
Get
(
"Location"
),
"session_error"
,
"USER_NOT_ACTIVE"
)
count
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
count
)
}
func
TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser
(
t
*
testing
.
T
)
{
cfg
,
cleanup
:=
newOIDCTestProvider
(
t
,
oidcProviderFixture
{
Subject
:
"oidc-subject-compat"
,
...
...
@@ -302,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
Nil
(
t
,
session
.
TargetUserID
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
require
.
Equal
(
t
,
"legacy@example.com"
,
session
.
UpstreamIdentityClaims
[
"compat_email"
])
...
...
@@ -606,6 +743,189 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-choice-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-choice-subject-1"
)
.
SetResolvedEmail
(
"oidc-choice-subject-1@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-choice-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"redirect"
:
"/dashboard"
,
"email"
:
"fresh@example.com"
,
"resolved_email"
:
"fresh@example.com"
,
"force_email_on_signup"
:
true
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-choice-browser"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"force_email_on_signup"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-no-adoption-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-subject-no-adoption"
)
.
SetResolvedEmail
(
"8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-browser-no-adoption"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
"suggested_display_name"
:
"OIDC Legacy"
,
"suggested_avatar_url"
:
"https://cdn.example/oidc-legacy.png"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-browser-no-adoption"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
NotEmpty
(
t
,
responseData
[
"refresh_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
session
.
ResolvedEmail
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"oidc_user"
,
userEntity
.
Username
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"oidc"
),
authidentity
.
ProviderKeyEQ
(
"https://issuer.example.com"
),
authidentity
.
ProviderSubjectEQ
(
"oidc-subject-no-adoption"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
func
TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
existingOwner
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingOwner
.
ID
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-conflict-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
SetResolvedEmail
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-conflict-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-conflict-browser"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
recorder
.
Code
)
payload
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
payload
[
"reason"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
type
oidcProviderFixture
struct
{
Subject
string
PreferredUsername
string
...
...
backend/internal/handler/auth_session_revocation_test.go
0 → 100644
View file @
ddf80f5e
//go:build unit
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
29
,
Email
:
"session@example.com"
,
Username
:
"session-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
7
,
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
&
AuthHandler
{
authService
:
authService
}
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/revoke-all-sessions"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
29
})
handler
.
RevokeAllSessions
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
[]
int64
{
29
},
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
8
),
repo
.
user
.
TokenVersion
)
var
resp
struct
{
Code
int
`json:"code"`
Data
struct
{
Message
string
`json:"message"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
"All sessions have been revoked. Please log in again."
,
resp
.
Data
.
Message
)
}
backend/internal/handler/auth_wechat_oauth.go
View file @
ddf80f5e
...
...
@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
if
err
:=
h
.
createWeChatPendingSession
(
c
,
normalizedIntent
,
providerSubject
,
existingIdentityUser
.
Email
,
redirectTo
,
browserSessionKey
,
upstreamClaims
,
tokenPair
,
nil
,
&
user
.
ID
);
err
!=
nil
{
if
err
:=
h
.
createWeChatPendingSession
(
c
,
normalizedIntent
,
providerSubject
,
existingIdentityUser
.
Email
,
redirectTo
,
browserSessionKey
,
upstreamClaims
,
nil
,
nil
,
&
existingIdentityUser
.
ID
);
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
"failed to continue oauth login"
,
""
)
return
}
...
...
@@ -476,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
}
func
(
h
*
AuthHandler
)
wechatPaymentResumeService
()
*
service
.
PaymentResumeService
{
var
legacyKey
[]
byte
key
,
err
:=
payment
.
ProvideEncryptionKey
(
h
.
cfg
)
if
err
!
=
nil
{
return
service
.
NewPaymentResumeService
(
nil
)
if
err
=
=
nil
{
legacyKey
=
[]
byte
(
key
)
}
return
service
.
NewPaymentResumeService
(
[]
byte
(
k
ey
)
)
return
service
.
New
LegacyAware
PaymentResumeService
(
legacyK
ey
)
}
type
completeWeChatOAuthRequest
struct
{
...
...
@@ -530,6 +526,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
updatedSession
,
handled
,
err
:=
h
.
legacyCompleteRegistrationSessionStatus
(
c
,
session
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
else
if
handled
{
c
.
JSON
(
http
.
StatusOK
,
buildPendingOAuthSessionStatusPayload
(
updatedSession
))
return
}
else
{
session
=
updatedSession
}
if
err
:=
h
.
ensureBackendModeAllowsNewUserLogin
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -547,7 +552,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
decision
,
err
:=
h
.
upsert
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
decision
,
err
:=
h
.
ensure
PendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
AdoptDisplayName
:
req
.
AdoptDisplayName
,
AdoptAvatar
:
req
.
AdoptAvatar
,
})
...
...
@@ -823,7 +828,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
if
user
,
err
:=
singleWeChatIdentityUser
(
records
);
err
!=
nil
||
user
!=
nil
{
return
user
,
err
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
}
...
...
@@ -847,7 +855,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED"
,
"failed to inspect auth identity channel ownership"
)
.
WithCause
(
err
)
}
if
user
,
err
:=
singleWeChatChannelUser
(
records
);
err
!=
nil
||
user
!=
nil
{
return
user
,
err
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
}
...
...
@@ -866,7 +877,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
if
err
!=
nil
{
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
return
singleWeChatIdentityUser
(
records
)
user
,
err
:=
singleWeChatIdentityUser
(
records
)
if
err
!=
nil
||
user
==
nil
{
return
user
,
err
}
return
findActiveUserByID
(
ctx
,
client
,
user
.
ID
)
}
func
wechatCompatibleProviderKeys
(
providerKey
string
)
[]
string
{
...
...
backend/internal/handler/auth_wechat_oauth_test.go
View file @
ddf80f5e
...
...
@@ -213,6 +213,151 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
require
.
Equal
(
t
,
"third_party_signup"
,
completion
[
"choice_reason"
])
}
func
TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
wechatOAuthUserInfoURL
=
originalUserInfoURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
{
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`
))
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/userinfo"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
wechatOAuthUserInfoURL
=
upstream
.
URL
+
"/sns/userinfo"
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"open"
,
"wx-open-app"
,
"wx-open-secret"
,
"https://app.example.com/auth/wechat/callback"
))
defer
client
.
Close
()
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
wechatSyntheticEmail
(
"union-456"
))
.
SetUsername
(
"wechat-existing-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"union-456"
)
.
SetMetadata
(
map
[
string
]
any
{
"username"
:
"wechat-existing-user"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatOAuthStateCookieName
,
"state-123"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthRedirectCookieName
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthModeCookieName
,
"open"
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-123"
))
c
.
Request
=
req
handler
.
WeChatOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Equal
(
t
,
"https://app.example.com/auth/wechat/callback"
,
recorder
.
Header
()
.
Get
(
"Location"
))
sessionCookie
:=
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
)
require
.
NotNil
(
t
,
sessionCookie
)
session
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
SessionTokenEQ
(
decodeCookieValueForTest
(
t
,
sessionCookie
.
Value
)))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
completion
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
_
,
hasAccessToken
:=
completion
[
"access_token"
]
require
.
False
(
t
,
hasAccessToken
)
_
,
hasRefreshToken
:=
completion
[
"refresh_token"
]
require
.
False
(
t
,
hasRefreshToken
)
}
func
TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
wechatOAuthUserInfoURL
=
originalUserInfoURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
{
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`
))
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/userinfo"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
wechatOAuthUserInfoURL
=
upstream
.
URL
+
"/sns/userinfo"
handler
,
client
:=
newWeChatOAuthTestHandler
(
t
,
false
)
defer
client
.
Close
()
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
wechatSyntheticEmail
(
"union-disabled"
))
.
SetUsername
(
"disabled-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusDisabled
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingUser
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"union-disabled"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatOAuthStateCookieName
,
"state-disabled"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthRedirectCookieName
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthModeCookieName
,
"open"
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-disabled"
))
c
.
Request
=
req
handler
.
WeChatOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Nil
(
t
,
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
))
assertOAuthRedirectError
(
t
,
recorder
.
Header
()
.
Get
(
"Location"
),
"session_error"
,
"USER_NOT_ACTIVE"
)
count
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
count
)
}
func
TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
t
.
Cleanup
(
func
()
{
...
...
@@ -233,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"mp"
,
"wx-mp-app"
,
"wx-mp-secret"
,
"/auth/wechat/callback"
))
defer
client
.
Close
()
handler
.
cfg
.
Totp
.
EncryptionKey
=
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
handler
.
cfg
.
Totp
.
EncryptionKeyConfigured
=
true
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
...
...
@@ -270,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
require
.
Equal
(
t
,
"/purchase?from=wechat"
,
claims
.
RedirectTo
)
}
func
TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
{
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`
))
return
}
http
.
NotFound
(
w
,
r
)
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
handler
,
client
:=
newWeChatOAuthTestHandlerWithSettings
(
t
,
false
,
wechatOAuthTestSettings
(
"mp"
,
"wx-mp-app"
,
"wx-mp-secret"
,
"/auth/wechat/callback"
))
defer
client
.
Close
()
legacyKeyHex
:=
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
explicitSigningKey
:=
"explicit-payment-resume-signing-key"
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
explicitSigningKey
)
handler
.
cfg
.
Totp
.
EncryptionKey
=
legacyKeyHex
handler
.
cfg
.
Totp
.
EncryptionKeyConfigured
=
true
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthStateName
,
"state-mixed"
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthRedirect
,
"/purchase?from=wechat"
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthContextName
,
`{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`
))
req
.
AddCookie
(
encodedCookie
(
wechatPaymentOAuthScope
,
"snsapi_base"
))
c
.
Request
=
req
handler
.
WeChatPaymentOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
location
:=
recorder
.
Header
()
.
Get
(
"Location"
)
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
fragment
,
err
:=
url
.
ParseQuery
(
parsed
.
Fragment
)
require
.
NoError
(
t
,
err
)
token
:=
fragment
.
Get
(
"wechat_resume_token"
)
require
.
NotEmpty
(
t
,
token
)
claims
,
err
:=
service
.
NewPaymentResumeService
([]
byte
(
explicitSigningKey
))
.
ParseWeChatPaymentResumeToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"openid-mixed-key"
,
claims
.
OpenID
)
require
.
Equal
(
t
,
payment
.
TypeWxpay
,
claims
.
PaymentType
)
require
.
Equal
(
t
,
"18.8"
,
claims
.
Amount
)
require
.
Equal
(
t
,
payment
.
OrderTypeSubscription
,
claims
.
OrderType
)
require
.
EqualValues
(
t
,
9
,
claims
.
PlanID
)
require
.
Equal
(
t
,
"/purchase?from=wechat"
,
claims
.
RedirectTo
)
_
,
err
=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
.
ParseWeChatPaymentResumeToken
(
token
)
require
.
Error
(
t
,
err
)
}
func
TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels
(
t
*
testing
.
T
)
{
testCases
:=
[]
struct
{
name
string
...
...
@@ -620,7 +827,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require
.
Zero
(
t
,
count
)
}
func
TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession
(
t
*
testing
.
T
)
{
func
TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession
ReturnsPendingSession
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
...
...
@@ -693,27 +900,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require
.
Equal
(
t
,
http
.
StatusOK
,
completeRecorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
completeRecorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"adoption_required"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"wechat-union-456@wechat-connect.invalid"
))
.
consumed
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
IDEQ
(
pendingSession
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equa
l
(
t
,
"WeChat Display"
,
userEntity
.
Username
)
require
.
Ni
l
(
t
,
consumed
.
ConsumedAt
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
identityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
"wechat-main"
),
authidentity
.
ProviderSubjectEQ
(
"union-456"
),
)
.
Only
(
ctx
)
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
require
.
Equal
(
t
,
"WeChat Display"
,
identity
.
Metadata
[
"display_name"
])
require
.
Equal
(
t
,
"https://cdn.example/wechat.png"
,
identity
.
Metadata
[
"avatar_url"
])
require
.
Zero
(
t
,
identityCount
)
channel
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
channel
Count
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ProviderKeyEQ
(
"wechat-main"
),
...
...
@@ -721,25 +933,82 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open-app"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-123"
),
)
.
Only
(
ctx
)
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
identity
.
ID
,
channel
.
IdentityID
)
require
.
Equal
(
t
,
"union-456"
,
channel
.
Metadata
[
"unionid"
])
require
.
Zero
(
t
,
channelCount
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
decision
Count
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
pendingSession
.
ID
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
decisionCount
)
}
func
TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"wechat-complete-no-adoption-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"wechat-subject-no-adoption"
)
.
SetResolvedEmail
(
"wechat-subject-no-adoption@wechat-connect.invalid"
)
.
SetBrowserSessionKey
(
"wechat-browser-no-adoption"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"wechat_user"
,
"suggested_display_name"
:
"WeChat Legacy"
,
"suggested_avatar_url"
:
"https://cdn.example/wechat-legacy.png"
,
"mode"
:
"open"
,
"channel"
:
"open"
,
"channel_app_id"
:
"wx-open-app"
,
"channel_subject"
:
"openid-legacy"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
completeCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
completeReq
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/wechat/complete-registration"
,
body
)
completeReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
completeReq
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
completeReq
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"wechat-browser-no-adoption"
)})
completeCtx
.
Request
=
completeReq
handler
.
CompleteWeChatOAuthRegistration
(
completeCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
NotEmpty
(
t
,
responseData
[
"access_token"
])
require
.
NotEmpty
(
t
,
responseData
[
"refresh_token"
])
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
session
.
ResolvedEmail
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
True
(
t
,
decision
.
AdoptDisplayName
)
require
.
True
(
t
,
decision
.
AdoptAvatar
)
require
.
Equal
(
t
,
"wechat_user"
,
userEntity
.
Username
)
consumed
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
IDEQ
(
pendingSession
.
ID
))
.
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
wechatOAuthProviderKey
),
authidentity
.
ProviderSubjectEQ
(
"wechat-subject-no-adoption"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
func
TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity
(
t
*
testing
.
T
)
{
...
...
@@ -901,6 +1170,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
(
t
*
testing
.
T
)
{
handler
,
client
:=
newWeChatOAuthTestHandler
(
t
,
false
)
defer
client
.
Close
()
ctx
:=
context
.
Background
()
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"wechat-complete-choice-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"wechat-choice-subject-1"
)
.
SetResolvedEmail
(
"wechat-choice-subject-1@wechat-connect.invalid"
)
.
SetBrowserSessionKey
(
"wechat-choice-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"wechat_user"
,
})
.
SetLocalFlowState
(
map
[
string
]
any
{
oauthCompletionResponseKey
:
map
[
string
]
any
{
"step"
:
oauthPendingChoiceStep
,
"redirect"
:
"/dashboard"
,
"email"
:
"fresh@example.com"
,
"resolved_email"
:
"fresh@example.com"
,
"force_email_on_signup"
:
true
,
},
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
completeCtx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/wechat/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"wechat-choice-browser"
)})
completeCtx
.
Request
=
req
handler
.
CompleteWeChatOAuthRegistration
(
completeCtx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
responseData
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"pending_session"
,
responseData
[
"auth_result"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
responseData
[
"step"
])
require
.
Equal
(
t
,
true
,
responseData
[
"force_email_on_signup"
])
require
.
Empty
(
t
,
responseData
[
"access_token"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity
(
t
*
testing
.
T
)
{
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
...
...
@@ -1083,18 +1408,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
},
client
}
func
assertOAuthRedirectError
(
t
*
testing
.
T
,
location
string
,
errorCode
string
,
errorMessage
string
)
{
t
.
Helper
()
parsed
,
err
:=
url
.
Parse
(
location
)
require
.
NoError
(
t
,
err
)
fragment
,
err
:=
url
.
ParseQuery
(
parsed
.
Fragment
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
errorCode
,
fragment
.
Get
(
"error"
))
require
.
Equal
(
t
,
errorMessage
,
fragment
.
Get
(
"error_message"
))
}
type
wechatOAuthSettingRepoStub
struct
{
values
map
[
string
]
string
}
...
...
backend/internal/handler/payment_handler.go
View file @
ddf80f5e
...
...
@@ -2,9 +2,9 @@ package handler
import
(
"fmt"
"net/http"
"strconv"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
...
...
@@ -454,29 +454,65 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
type
PublicOrderResult
struct
{
ID
int64
`json:"id"`
OutTradeNo
string
`json:"out_trade_no"`
Amount
float64
`json:"amount"`
PayAmount
float64
`json:"pay_amount"`
PaymentType
string
`json:"payment_type"`
OrderType
string
`json:"order_type"`
Status
string
`json:"status"`
ID
int64
`json:"id"`
OutTradeNo
string
`json:"out_trade_no"`
Amount
float64
`json:"amount"`
PayAmount
float64
`json:"pay_amount"`
FeeRate
float64
`json:"fee_rate"`
PaymentType
string
`json:"payment_type"`
OrderType
string
`json:"order_type"`
Status
string
`json:"status"`
CreatedAt
time
.
Time
`json:"created_at"`
ExpiresAt
time
.
Time
`json:"expires_at"`
PaidAt
*
time
.
Time
`json:"paid_at,omitempty"`
CompletedAt
*
time
.
Time
`json:"completed_at,omitempty"`
RefundAmount
float64
`json:"refund_amount"`
RefundReason
*
string
`json:"refund_reason,omitempty"`
RefundRequestedAt
*
time
.
Time
`json:"refund_requested_at,omitempty"`
RefundRequestedBy
*
string
`json:"refund_requested_by,omitempty"`
RefundRequestReason
*
string
`json:"refund_request_reason,omitempty"`
PlanID
*
int64
`json:"plan_id,omitempty"`
}
var
errPaymentPublicOrderVerifyRemoved
=
infraerrors
.
New
(
http
.
StatusGone
,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED"
,
"public payment order verification by out_trade_no has been removed; use resume_token recovery instead"
,
)
.
WithMetadata
(
map
[
string
]
string
{
"replacement_endpoint"
:
"/api/v1/payment/public/orders/resolve"
,
"replacement_field"
:
"resume_token"
,
})
// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
func
buildPublicOrderResult
(
order
*
dbent
.
PaymentOrder
)
PublicOrderResult
{
return
PublicOrderResult
{
ID
:
order
.
ID
,
OutTradeNo
:
order
.
OutTradeNo
,
Amount
:
order
.
Amount
,
PayAmount
:
order
.
PayAmount
,
FeeRate
:
order
.
FeeRate
,
PaymentType
:
order
.
PaymentType
,
OrderType
:
order
.
OrderType
,
Status
:
order
.
Status
,
CreatedAt
:
order
.
CreatedAt
,
ExpiresAt
:
order
.
ExpiresAt
,
PaidAt
:
order
.
PaidAt
,
CompletedAt
:
order
.
CompletedAt
,
RefundAmount
:
order
.
RefundAmount
,
RefundReason
:
order
.
RefundReason
,
RefundRequestedAt
:
order
.
RefundRequestedAt
,
RefundRequestedBy
:
order
.
RefundRequestedBy
,
RefundRequestReason
:
order
.
RefundRequestReason
,
PlanID
:
order
.
PlanID
,
}
}
// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
func
(
h
*
PaymentHandler
)
VerifyOrderPublic
(
c
*
gin
.
Context
)
{
response
.
ErrorFrom
(
c
,
errPaymentPublicOrderVerifyRemoved
)
var
req
VerifyOrderRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
order
,
err
:=
h
.
paymentService
.
VerifyOrderPublic
(
c
.
Request
.
Context
(),
req
.
OutTradeNo
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
buildPublicOrderResult
(
order
))
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
...
...
@@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
PublicOrderResult
{
ID
:
order
.
ID
,
OutTradeNo
:
order
.
OutTradeNo
,
Amount
:
order
.
Amount
,
PayAmount
:
order
.
PayAmount
,
PaymentType
:
order
.
PaymentType
,
OrderType
:
order
.
OrderType
,
Status
:
order
.
Status
,
})
response
.
Success
(
c
,
buildPublicOrderResult
(
order
))
}
// requireAuth extracts the authenticated subject from the context.
...
...
backend/internal/handler/payment_handler_resume_test.go
View file @
ddf80f5e
...
...
@@ -4,16 +4,17 @@ package handler
import
(
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
...
...
@@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
}
}
func
TestVerifyOrderPublicReturns
Gon
e
(
t
*
testing
.
T
)
{
func
TestVerifyOrderPublicReturns
LegacyOrderStat
e
(
t
*
testing
.
T
)
{
t
.
Parallel
()
gin
.
SetMode
(
gin
.
TestMode
)
...
...
@@ -90,6 +91,32 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-verify@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-verify-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
88
)
.
SetPayAmount
(
90.64
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-VERIFY"
)
.
SetOutTradeNo
(
"legacy-order-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-verify"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPending
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
...
...
@@ -104,11 +131,238 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
h
.
VerifyOrderPublic
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusGone
,
recorder
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
struct
{
ID
int64
`json:"id"`
OutTradeNo
string
`json:"out_trade_no"`
Amount
float64
`json:"amount"`
PayAmount
float64
`json:"pay_amount"`
FeeRate
float64
`json:"fee_rate"`
PaymentType
string
`json:"payment_type"`
OrderType
string
`json:"order_type"`
Status
string
`json:"status"`
RefundAmount
float64
`json:"refund_amount"`
CreatedAt
string
`json:"created_at"`
ExpiresAt
string
`json:"expires_at"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
order
.
ID
,
resp
.
Data
.
ID
)
require
.
Equal
(
t
,
"legacy-order-no"
,
resp
.
Data
.
OutTradeNo
)
require
.
Equal
(
t
,
90.64
,
resp
.
Data
.
PayAmount
)
require
.
Equal
(
t
,
0.03
,
resp
.
Data
.
FeeRate
)
require
.
Equal
(
t
,
payment
.
TypeAlipay
,
resp
.
Data
.
PaymentType
)
require
.
Equal
(
t
,
payment
.
OrderTypeBalance
,
resp
.
Data
.
OrderType
)
require
.
Equal
(
t
,
service
.
OrderStatusPending
,
resp
.
Data
.
Status
)
require
.
Equal
(
t
,
0.0
,
resp
.
Data
.
RefundAmount
)
require
.
NotEmpty
(
t
,
resp
.
Data
.
CreatedAt
)
require
.
NotEmpty
(
t
,
resp
.
Data
.
ExpiresAt
)
}
func
TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"0123456789abcdef0123456789abcdef"
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_resolve?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-resolve@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-resolve-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
100
)
.
SetPayAmount
(
103
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-RESOLVE"
)
.
SetOutTradeNo
(
"resolve-order-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-resolve"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPaid
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetPaidAt
(
time
.
Now
())
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
resumeSvc
:=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
token
,
err
:=
resumeSvc
.
CreateToken
(
service
.
ResumeTokenClaims
{
OrderID
:
order
.
ID
,
UserID
:
user
.
ID
,
PaymentType
:
payment
.
TypeAlipay
,
CanonicalReturnURL
:
"https://app.example.com/payment/result"
,
})
require
.
NoError
(
t
,
err
)
configSvc
:=
service
.
NewPaymentConfigService
(
client
,
nil
,
[]
byte
(
"0123456789abcdef0123456789abcdef"
))
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
configSvc
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/resolve"
,
bytes
.
NewBufferString
(
`{"resume_token":"`
+
token
+
`"}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
ResolveOrderPublicByResumeToken
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
map
[
string
]
any
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
float64
(
order
.
ID
),
resp
.
Data
[
"id"
])
require
.
Equal
(
t
,
"resolve-order-no"
,
resp
.
Data
[
"out_trade_no"
])
require
.
Equal
(
t
,
100.0
,
resp
.
Data
[
"amount"
])
require
.
Equal
(
t
,
103.0
,
resp
.
Data
[
"pay_amount"
])
require
.
Equal
(
t
,
0.03
,
resp
.
Data
[
"fee_rate"
])
require
.
Equal
(
t
,
payment
.
TypeAlipay
,
resp
.
Data
[
"payment_type"
])
require
.
Equal
(
t
,
payment
.
OrderTypeBalance
,
resp
.
Data
[
"order_type"
])
require
.
Equal
(
t
,
service
.
OrderStatusPaid
,
resp
.
Data
[
"status"
])
require
.
Contains
(
t
,
resp
.
Data
,
"created_at"
)
require
.
Contains
(
t
,
resp
.
Data
,
"expires_at"
)
require
.
Contains
(
t
,
resp
.
Data
,
"refund_amount"
)
}
var
resp
response
.
Response
func
TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
t
.
Setenv
(
"PAYMENT_RESUME_SIGNING_KEY"
,
"0123456789abcdef0123456789abcdef"
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"public-resolve-mismatch@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetUsername
(
"public-resolve-mismatch-user"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
order
,
err
:=
client
.
PaymentOrder
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetAmount
(
100
)
.
SetPayAmount
(
103
)
.
SetFeeRate
(
0.03
)
.
SetRechargeCode
(
"PUBLIC-RESOLVE-MISMATCH"
)
.
SetOutTradeNo
(
"resolve-order-mismatch-no"
)
.
SetPaymentType
(
payment
.
TypeAlipay
)
.
SetPaymentTradeNo
(
"trade-public-resolve-mismatch"
)
.
SetOrderType
(
payment
.
OrderTypeBalance
)
.
SetStatus
(
service
.
OrderStatusPaid
)
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
time
.
Hour
))
.
SetPaidAt
(
time
.
Now
())
.
SetClientIP
(
"127.0.0.1"
)
.
SetSrcHost
(
"api.example.com"
)
.
Save
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
resumeSvc
:=
service
.
NewPaymentResumeService
([]
byte
(
"0123456789abcdef0123456789abcdef"
))
token
,
err
:=
resumeSvc
.
CreateToken
(
service
.
ResumeTokenClaims
{
OrderID
:
order
.
ID
,
UserID
:
user
.
ID
+
999
,
PaymentType
:
payment
.
TypeAlipay
,
CanonicalReturnURL
:
"https://app.example.com/payment/result"
,
})
require
.
NoError
(
t
,
err
)
configSvc
:=
service
.
NewPaymentConfigService
(
client
,
nil
,
[]
byte
(
"0123456789abcdef0123456789abcdef"
))
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
configSvc
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/resolve"
,
bytes
.
NewBufferString
(
`{"resume_token":"`
+
token
+
`"}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
ResolveOrderPublicByResumeToken
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Reason
string
`json:"reason"`
Message
string
`json:"message"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Code
)
require
.
Equal
(
t
,
"INVALID_RESUME_TOKEN"
,
resp
.
Reason
)
}
func
TestVerifyOrderPublicRejectsBlankOutTradeNo
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:payment_handler_public_verify_blank?mode=memory&cache=shared"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
drv
:=
entsql
.
OpenDB
(
dialect
.
SQLite
,
db
)
client
:=
enttest
.
NewClient
(
t
,
enttest
.
WithOptions
(
dbent
.
Driver
(
drv
)))
t
.
Cleanup
(
func
()
{
_
=
client
.
Close
()
})
paymentSvc
:=
service
.
NewPaymentService
(
client
,
payment
.
NewRegistry
(),
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
h
:=
NewPaymentHandler
(
paymentSvc
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
ctx
,
_
:=
gin
.
CreateTestContext
(
recorder
)
ctx
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/payment/public/orders/verify"
,
bytes
.
NewBufferString
(
`{"out_trade_no":" "}`
),
)
ctx
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
h
.
VerifyOrderPublic
(
ctx
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Reason
string
`json:"reason"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusGone
,
resp
.
Code
)
require
.
Equal
(
t
,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED"
,
resp
.
Reason
)
require
.
Contains
(
t
,
resp
.
Message
,
"removed"
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Code
)
require
.
Equal
(
t
,
"INVALID_OUT_TRADE_NO"
,
resp
.
Reason
)
}
backend/internal/handler/user_handler.go
View file @
ddf80f5e
...
...
@@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
return
}
updatedUser
,
err
:=
h
.
userService
.
UnbindUserAuthProvider
(
updatedUser
,
unbound
,
err
:=
h
.
userService
.
UnbindUserAuthProvider
WithResult
(
c
.
Request
.
Context
(),
subject
.
UserID
,
c
.
Param
(
"provider"
),
...
...
@@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
unbound
&&
h
.
authService
!=
nil
{
if
err
:=
h
.
authService
.
RevokeAllUserTokens
(
c
.
Request
.
Context
(),
subject
.
UserID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
}
profileResp
,
err
:=
h
.
buildUserProfileResponse
(
c
.
Request
.
Context
(),
subject
.
UserID
,
updatedUser
)
if
err
!=
nil
{
...
...
@@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
thirdParty
:=
thirdPartyIdentityProviders
(
identities
)
var
avatarSource
*
userProfileSourceContext
if
strings
.
TrimSpace
(
user
.
AvatarURL
)
!=
""
&&
len
(
thirdParty
)
==
1
{
avatarSource
=
buildUserProfileSourceContext
(
thirdParty
[
0
]
.
Provider
)
avatarValue
:=
strings
.
TrimSpace
(
user
.
AvatarURL
)
for
_
,
summary
:=
range
thirdParty
{
if
avatarValue
!=
""
&&
avatarValue
==
strings
.
TrimSpace
(
summary
.
AvatarURL
)
{
avatarSource
=
buildUserProfileSourceContext
(
summary
.
Provider
)
break
}
}
usernameValue
:=
strings
.
TrimSpace
(
user
.
Username
)
...
...
@@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
break
}
}
if
usernameSource
==
nil
&&
usernameValue
!=
""
&&
len
(
thirdParty
)
==
1
{
usernameSource
=
buildUserProfileSourceContext
(
thirdParty
[
0
]
.
Provider
)
}
profileSources
:=
map
[
string
]
*
userProfileSourceContext
{}
if
avatarSource
!=
nil
{
...
...
backend/internal/handler/user_handler_test.go
View file @
ddf80f5e
...
...
@@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require
.
Equal
(
t
,
"https://issuer.example.com"
,
resp
.
Data
.
Identities
.
OIDC
.
ProviderKey
)
require
.
False
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
Bound
)
require
.
True
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
CanBind
)
require
.
Contains
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
BindStartPath
,
"/api/v1/auth/oauth/wechat/start"
)
require
.
Contains
(
t
,
resp
.
Data
.
Identities
.
WeChat
.
BindStartPath
,
"/api/v1/auth/oauth/wechat/
bind/
start"
)
}
func
TestUserHandlerGetProfileReturnsLegacyCompatibilityFields
(
t
*
testing
.
T
)
{
...
...
@@ -270,18 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
AvatarURL
:
"https://cdn.example.com/linuxdo.png"
,
AvatarSource
:
"remote_url"
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-21"
,
VerifiedAt
:
&
verifiedAt
,
Metadata
:
map
[
string
]
any
{
"username"
:
"linuxdo-handle"
,
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-21"
,
VerifiedAt
:
&
verifiedAt
,
Metadata
:
map
[
string
]
any
{
"username"
:
"linuxdo-handle"
,
"avatar_url"
:
"https://cdn.example.com/linuxdo.png"
,
},
},
},
},
}
}
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
nil
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
...
...
@@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require
.
Equal
(
t
,
"linuxdo"
,
usernameSource
[
"source"
])
}
func
TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
22
,
Email
:
"edited-profile@example.com"
,
Username
:
"custom-name"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
AvatarURL
:
"https://cdn.example.com/custom.png"
,
AvatarSource
:
"remote_url"
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-22"
,
Metadata
:
map
[
string
]
any
{
"username"
:
"linuxdo-handle"
,
"avatar_url"
:
"https://cdn.example.com/linuxdo.png"
,
},
},
},
}
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
nil
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/user/profile"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
22
})
handler
.
GetProfile
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
map
[
string
]
any
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
recorder
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
NotContains
(
t
,
resp
.
Data
,
"avatar_source"
)
require
.
NotContains
(
t
,
resp
.
Data
,
"username_source"
)
require
.
NotContains
(
t
,
resp
.
Data
,
"profile_sources"
)
}
type
userHandlerEmailCacheStub
struct
{
data
*
service
.
VerificationCodeData
}
type
userHandlerRefreshTokenCacheStub
struct
{
revokedUserIDs
[]
int64
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
StoreRefreshToken
(
context
.
Context
,
string
,
*
service
.
RefreshTokenData
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetRefreshToken
(
context
.
Context
,
string
)
(
*
service
.
RefreshTokenData
,
error
)
{
return
nil
,
service
.
ErrRefreshTokenNotFound
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteRefreshToken
(
context
.
Context
,
string
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteUserRefreshTokens
(
_
context
.
Context
,
userID
int64
)
error
{
s
.
revokedUserIDs
=
append
(
s
.
revokedUserIDs
,
userID
)
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
DeleteTokenFamily
(
context
.
Context
,
string
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
AddToUserTokenSet
(
context
.
Context
,
int64
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
AddToFamilyTokenSet
(
context
.
Context
,
string
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetUserTokenHashes
(
context
.
Context
,
int64
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
GetFamilyTokenHashes
(
context
.
Context
,
string
)
([]
string
,
error
)
{
return
nil
,
nil
}
func
(
s
*
userHandlerRefreshTokenCacheStub
)
IsTokenInFamily
(
context
.
Context
,
string
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
s
*
userHandlerEmailCacheStub
)
GetVerificationCode
(
context
.
Context
,
string
)
(
*
service
.
VerificationCodeData
,
error
)
{
return
s
.
data
,
nil
}
...
...
@@ -495,6 +588,98 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
require
.
Equal
(
t
,
false
,
linuxdoBinding
[
"bound"
])
}
func
TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
23
,
Email
:
"identity@example.com"
,
Username
:
"identity-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"email"
,
ProviderKey
:
"email"
,
ProviderSubject
:
"identity@example.com"
,
},
{
ProviderType
:
"linuxdo"
,
ProviderKey
:
"linuxdo"
,
ProviderSubject
:
"linuxdo-subject-23"
,
},
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
authService
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodDelete
,
"/api/v1/user/account-bindings/linuxdo"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
23
})
c
.
Params
=
gin
.
Params
{{
Key
:
"provider"
,
Value
:
"linuxdo"
}}
handler
.
UnbindIdentity
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
[]
int64
{
23
},
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
5
),
repo
.
user
.
TokenVersion
)
}
func
TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
userHandlerRepoStub
{
user
:
&
service
.
User
{
ID
:
24
,
Email
:
"identity@example.com"
,
Username
:
"identity-user"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
TokenVersion
:
4
,
},
identities
:
[]
service
.
UserAuthIdentityRecord
{
{
ProviderType
:
"email"
,
ProviderKey
:
"email"
,
ProviderSubject
:
"identity@example.com"
,
},
},
}
refreshTokenCache
:=
&
userHandlerRefreshTokenCacheStub
{}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
,
},
}
authService
:=
service
.
NewAuthService
(
nil
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
handler
:=
NewUserHandler
(
service
.
NewUserService
(
repo
,
nil
,
nil
,
nil
),
authService
,
nil
,
nil
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodDelete
,
"/api/v1/user/account-bindings/linuxdo"
,
nil
)
c
.
Set
(
string
(
middleware2
.
ContextKeyUser
),
middleware2
.
AuthSubject
{
UserID
:
24
})
c
.
Params
=
gin
.
Params
{{
Key
:
"provider"
,
Value
:
"linuxdo"
}}
handler
.
UnbindIdentity
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Empty
(
t
,
repo
.
unbound
)
require
.
Empty
(
t
,
refreshTokenCache
.
revokedUserIDs
)
require
.
Equal
(
t
,
int64
(
4
),
repo
.
user
.
TokenVersion
)
}
func
TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
@@ -587,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
require
.
Equal
(
t
,
"wechat"
,
resp
.
Data
.
Provider
)
require
.
Equal
(
t
,
"GET"
,
resp
.
Data
.
Method
)
require
.
True
(
t
,
resp
.
Data
.
UseBrowserRedirect
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"/api/v1/auth/oauth/wechat/start"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"/api/v1/auth/oauth/wechat/
bind/
start"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"intent=bind_current_user"
)
require
.
Contains
(
t
,
resp
.
Data
.
AuthorizeURL
,
"redirect=%2Fsettings%2Fprofile"
)
}
backend/internal/payment/provider/wxpay.go
View file @
ddf80f5e
...
...
@@ -60,11 +60,6 @@ const (
wxpayEventTransactionSuccess
=
"TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const
(
wxpayErrNoAuth
=
"NO_AUTH"
)
var
(
wxpayNativePrepay
=
func
(
ctx
context
.
Context
,
svc
native
.
NativeApiService
,
req
native
.
PrepayRequest
)
(
*
native
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
return
svc
.
Prepay
(
ctx
,
req
)
...
...
@@ -200,14 +195,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
case
wxpayModeJSAPI
:
return
w
.
prepayJSAPI
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
case
wxpayModeH5
:
resp
,
err
:=
w
.
prepayH5
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
if
err
==
nil
{
return
resp
,
nil
}
if
strings
.
Contains
(
err
.
Error
(),
wxpayErrNoAuth
)
{
return
nil
,
fmt
.
Errorf
(
"wxpay h5 payments are not authorized for this merchant: %w"
,
err
)
}
return
nil
,
err
return
w
.
prepayH5
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
case
wxpayModeNative
:
return
w
.
prepayNative
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
)
default
:
...
...
backend/internal/payment/provider/wxpay_test.go
View file @
ddf80f5e
...
...
@@ -8,6 +8,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"net/url"
"strings"
"testing"
...
...
@@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
t
.
Fatalf
(
"pay_url = %q, want redirect_url query appended"
,
resp
.
PayURL
)
}
}
func
TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback
(
t
*
testing
.
T
)
{
origJSAPIPrepay
:=
wxpayJSAPIPrepayWithRequestPayment
origNativePrepay
:=
wxpayNativePrepay
origH5Prepay
:=
wxpayH5Prepay
t
.
Cleanup
(
func
()
{
wxpayJSAPIPrepayWithRequestPayment
=
origJSAPIPrepay
wxpayNativePrepay
=
origNativePrepay
wxpayH5Prepay
=
origH5Prepay
})
jsapiCalls
:=
0
nativeCalls
:=
0
h5Calls
:=
0
wxpayJSAPIPrepayWithRequestPayment
=
func
(
ctx
context
.
Context
,
svc
jsapi
.
JsapiApiService
,
req
jsapi
.
PrepayRequest
)
(
*
jsapi
.
PrepayWithRequestPaymentResponse
,
*
core
.
APIResult
,
error
)
{
jsapiCalls
++
return
&
jsapi
.
PrepayWithRequestPaymentResponse
{},
nil
,
nil
}
wxpayH5Prepay
=
func
(
ctx
context
.
Context
,
svc
h5
.
H5ApiService
,
req
h5
.
PrepayRequest
)
(
*
h5
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
h5Calls
++
return
nil
,
nil
,
errors
.
New
(
"NO_AUTH"
)
}
wxpayNativePrepay
=
func
(
ctx
context
.
Context
,
svc
native
.
NativeApiService
,
req
native
.
PrepayRequest
)
(
*
native
.
PrepayResponse
,
*
core
.
APIResult
,
error
)
{
nativeCalls
++
return
&
native
.
PrepayResponse
{
CodeUrl
:
core
.
String
(
"weixin://wxpay/bizpayurl?pr=fallback-native"
),
},
nil
,
nil
}
provider
:=
&
Wxpay
{
config
:
map
[
string
]
string
{
"appId"
:
"wx123"
,
"mchId"
:
"mch123"
,
},
coreClient
:
&
core
.
Client
{},
}
resp
,
err
:=
provider
.
CreatePayment
(
context
.
Background
(),
payment
.
CreatePaymentRequest
{
OrderID
:
"sub2_100"
,
Amount
:
"66.88"
,
PaymentType
:
payment
.
TypeWxpay
,
Subject
:
"Balance Recharge"
,
NotifyURL
:
"https://merchant.example/payment/notify"
,
ClientIP
:
"203.0.113.10"
,
IsMobile
:
true
,
})
if
err
==
nil
{
t
.
Fatal
(
"expected no-auth error, got nil"
)
}
if
jsapiCalls
!=
0
{
t
.
Fatalf
(
"jsapi prepay calls = %d, want 0"
,
jsapiCalls
)
}
if
h5Calls
!=
1
{
t
.
Fatalf
(
"h5 prepay calls = %d, want 1"
,
h5Calls
)
}
if
nativeCalls
!=
0
{
t
.
Fatalf
(
"native prepay calls = %d, want 0"
,
nativeCalls
)
}
if
resp
!=
nil
{
t
.
Fatalf
(
"expected nil response, got %+v"
,
resp
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"NO_AUTH"
)
{
t
.
Fatalf
(
"error = %v, want NO_AUTH"
,
err
)
}
}
backend/internal/payment/wire.go
View file @
ddf80f5e
...
...
@@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func
ProvideEncryptionKey
(
cfg
*
config
.
Config
)
(
EncryptionKey
,
error
)
{
if
cfg
.
Totp
.
EncryptionKey
==
""
{
if
cfg
==
nil
{
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config and resume signing will be unavailable"
)
return
nil
,
nil
}
keyHex
:=
strings
.
TrimSpace
(
cfg
.
Totp
.
EncryptionKey
)
if
keyHex
==
""
{
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config will be unavailable"
)
return
nil
,
nil
}
key
,
err
:=
hex
.
DecodeString
(
cfg
.
Totp
.
EncryptionKey
)
// Reject auto-generated TOTP keys for payment signing.
// They change across restarts/instances and can silently break resume-token flows.
if
!
cfg
.
Totp
.
EncryptionKeyConfigured
{
slog
.
Warn
(
"payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens"
)
return
nil
,
nil
}
key
,
err
:=
hex
.
DecodeString
(
keyHex
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid payment encryption key (hex decode): %w"
,
err
)
}
...
...
backend/internal/payment/wire_test.go
0 → 100644
View file @
ddf80f5e
package
payment
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func
TestProvideEncryptionKeySkipsAutoGeneratedTotpKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
strings
.
Repeat
(
"a"
,
64
),
EncryptionKeyConfigured
:
false
,
},
}
key
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
!=
nil
{
t
.
Fatalf
(
"ProvideEncryptionKey returned error: %v"
,
err
)
}
if
len
(
key
)
!=
0
{
t
.
Fatalf
(
"encryption key len = %d, want 0"
,
len
(
key
))
}
}
func
TestProvideEncryptionKeyUsesConfiguredTotpKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
,
EncryptionKeyConfigured
:
true
,
},
}
key
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
!=
nil
{
t
.
Fatalf
(
"ProvideEncryptionKey returned error: %v"
,
err
)
}
if
len
(
key
)
!=
32
{
t
.
Fatalf
(
"encryption key len = %d, want 32"
,
len
(
key
))
}
}
func
TestProvideEncryptionKeyRejectsConfiguredInvalidLength
(
t
*
testing
.
T
)
{
t
.
Parallel
()
cfg
:=
&
config
.
Config
{
Totp
:
config
.
TotpConfig
{
EncryptionKey
:
"abcd"
,
EncryptionKeyConfigured
:
true
,
},
}
_
,
err
:=
ProvideEncryptionKey
(
cfg
)
if
err
==
nil
{
t
.
Fatal
(
"expected error for invalid key length"
)
}
}
backend/internal/repository/auth_identity_legacy_migration_integration_test.go
View file @
ddf80f5e
...
...
@@ -4,6 +4,7 @@ package repository
import
(
"context"
"database/sql"
"os"
"path/filepath"
"strconv"
...
...
@@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
...
...
@@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoMalformedUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
...
...
@@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
userIDs
:=
make
([]
int64
,
0
,
8
)
for
_
,
email
:=
range
[]
string
{
...
...
@@ -643,6 +572,388 @@ FROM auth_identity_migration_reports
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
`
)
.
Scan
(
&
afterCount
))
`
)
.
Scan
(
&
afterCount
))
require
.
Equal
(
t
,
beforeCount
,
afterCount
)
}
func
TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migrationPath
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
new
(
int64
)))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migrationSQL
))
require
.
NoError
(
t
,
err
)
var
linuxDoIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`
)
.
Scan
(
&
linuxDoIdentityCount
))
require
.
Zero
(
t
,
linuxDoIdentityCount
)
var
wechatIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`
)
.
Scan
(
&
wechatIdentityCount
))
require
.
Zero
(
t
,
wechatIdentityCount
)
var
wechatChannelCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`
)
.
Scan
(
&
wechatChannelCount
))
require
.
Zero
(
t
,
wechatChannelCount
)
}
func
TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migration115Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migration115SQL
,
err
:=
os
.
ReadFile
(
migration115Path
)
require
.
NoError
(
t
,
err
)
migration116Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"116_auth_identity_legacy_external_safety_reports.sql"
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
var
linuxDoFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
&
linuxDoFirstLegacyID
))
var
linuxDoSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
&
linuxDoSecondLegacyID
))
var
wechatFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
&
wechatFirstLegacyID
))
var
wechatSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
&
wechatSecondLegacyID
))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration115SQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration116SQL
))
require
.
NoError
(
t
,
err
)
var
identityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`
)
.
Scan
(
&
identityCount
))
require
.
Zero
(
t
,
identityCount
)
var
conflictReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
conflictReportCount
))
require
.
Equal
(
t
,
4
,
conflictReportCount
)
var
winnerAttributedReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
winnerAttributedReportCount
))
require
.
Zero
(
t
,
winnerAttributedReportCount
)
}
func
TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migration108aPath
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"108a_widen_auth_identity_migration_report_type.sql"
)
migration108aSQL
,
err
:=
os
.
ReadFile
(
migration108aPath
)
require
.
NoError
(
t
,
err
)
migration109Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"109_auth_identity_compat_backfill.sql"
)
migration109SQL
,
err
:=
os
.
ReadFile
(
migration109Path
)
require
.
NoError
(
t
,
err
)
migration116Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"116_auth_identity_legacy_external_safety_reports.sql"
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
`
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(40);
`
)
require
.
NoError
(
t
,
err
)
var
oidcSyntheticUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
oidcSyntheticUserID
))
var
linuxdoLegacyUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxdoLegacyUserID
))
var
invalidMetadataLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
RETURNING id
`
,
linuxdoLegacyUserID
)
.
Scan
(
&
invalidMetadataLegacyID
))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration108aSQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration109SQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration116SQL
))
require
.
NoError
(
t
,
err
)
var
reportTypeWidth
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`
)
.
Scan
(
&
reportTypeWidth
))
require
.
Equal
(
t
,
80
,
reportTypeWidth
)
var
oidcSyntheticRecoveryReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`
,
strconv
.
FormatInt
(
oidcSyntheticUserID
,
10
))
.
Scan
(
&
oidcSyntheticRecoveryReportCount
))
require
.
Equal
(
t
,
1
,
oidcSyntheticRecoveryReportCount
)
var
invalidMetadataReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
invalidMetadataLegacyID
,
10
))
.
Scan
(
&
invalidMetadataReportCount
))
require
.
Equal
(
t
,
1
,
invalidMetadataReportCount
)
}
func
prepareLegacyExternalIdentitiesTable
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
ctx
context
.
Context
)
{
t
.
Helper
()
_
,
err
:=
tx
.
ExecContext
(
ctx
,
`
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`
)
require
.
NoError
(
t
,
err
)
}
func
truncateAuthIdentityLegacyFixtureTables
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
ctx
context
.
Context
)
{
t
.
Helper
()
_
,
err
:=
tx
.
ExecContext
(
ctx
,
`
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
pending_auth_sessions,
auth_identities,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`
)
require
.
NoError
(
t
,
err
)
}
backend/internal/repository/migrations_runner.go
View file @
ddf80f5e
...
...
@@ -51,34 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const
migrationsAdvisoryLockID
int64
=
694208311321144027
const
migrationsLockRetryInterval
=
500
*
time
.
Millisecond
const
nonTransactionalMigrationSuffix
=
"_notx.sql"
const
paymentOrdersOutTradeNoUniqueMigration
=
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const
paymentOrdersOutTradeNoUniqueIndex
=
"paymentorder_out_trade_no_unique"
type
migrationChecksumCompatibilityRule
struct
{
fileChecksum
string
acceptedDBChecksum
map
[
string
]
struct
{}
acceptedChecksums
map
[
string
]
struct
{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var
migrationChecksumCompatibilityRules
=
map
[
string
]
migrationChecksumCompatibilityRule
{
"054_drop_legacy_cache_columns.sql"
:
{
fileChecksum
:
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d"
,
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"
:
{},
},
},
"061_add_usage_log_request_type.sql"
:
{
fileChecksum
:
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c"
,
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0"
:
{},
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"
:
{},
},
},
"109_auth_identity_compat_backfill.sql"
:
{
fileChecksum
:
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
acceptedDBChecksum
:
map
[
string
]
struct
{}{
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3"
:
{},
},
},
"054_drop_legacy_cache_columns.sql"
:
newMigrationChecksumCompatibilityRule
(
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d"
,
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"
),
"061_add_usage_log_request_type.sql"
:
newMigrationChecksumCompatibilityRule
(
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c"
,
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0"
,
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"
),
"109_auth_identity_compat_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
),
"110_pending_auth_and_provider_default_grants.sql"
:
newMigrationChecksumCompatibilityRule
(
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279"
,
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"
),
"112_add_payment_order_provider_key_snapshot.sql"
:
newMigrationChecksumCompatibilityRule
(
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99"
,
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"
),
"115_auth_identity_legacy_external_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
),
"116_auth_identity_legacy_external_safety_reports.sql"
:
newMigrationChecksumCompatibilityRule
(
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
),
"118_wechat_dual_mode_and_auth_source_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0"
,
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"
,
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"
),
"119_enforce_payment_orders_out_trade_no_unique.sql"
:
newMigrationChecksumCompatibilityRule
(
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
newMigrationChecksumCompatibilityRule
(
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
),
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57"
,
"6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"
),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
...
...
@@ -205,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if
nonTx
{
if
err
:=
prepareNonTransactionalMigration
(
ctx
,
db
,
name
);
err
!=
nil
{
return
fmt
.
Errorf
(
"prepare migration %s: %w"
,
name
,
err
)
}
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements
:=
splitSQLStatements
(
content
)
...
...
@@ -254,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return
nil
}
func
prepareNonTransactionalMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
name
string
)
error
{
switch
name
{
case
paymentOrdersOutTradeNoUniqueMigration
:
return
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
,
db
)
default
:
return
nil
}
}
func
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
error
{
duplicates
,
err
:=
findDuplicatePaymentOrderOutTradeNos
(
ctx
,
db
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"precheck duplicate out_trade_no: %w"
,
err
)
}
if
len
(
duplicates
)
>
0
{
return
fmt
.
Errorf
(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s"
,
paymentOrdersOutTradeNoUniqueMigration
,
strings
.
Join
(
duplicates
,
", "
),
)
}
invalid
,
err
:=
indexIsInvalid
(
ctx
,
db
,
paymentOrdersOutTradeNoUniqueIndex
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
if
!
invalid
{
return
nil
}
if
_
,
err
:=
db
.
ExecContext
(
ctx
,
fmt
.
Sprintf
(
"DROP INDEX CONCURRENTLY IF EXISTS %s"
,
paymentOrdersOutTradeNoUniqueIndex
));
err
!=
nil
{
return
fmt
.
Errorf
(
"drop invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
return
nil
}
func
findDuplicatePaymentOrderOutTradeNos
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
([]
string
,
error
)
{
rows
,
err
:=
db
.
QueryContext
(
ctx
,
`
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
duplicates
:=
make
([]
string
,
0
,
5
)
for
rows
.
Next
()
{
var
outTradeNo
string
var
duplicateCount
int
if
err
:=
rows
.
Scan
(
&
outTradeNo
,
&
duplicateCount
);
err
!=
nil
{
return
nil
,
err
}
duplicates
=
append
(
duplicates
,
fmt
.
Sprintf
(
"%s (count=%d)"
,
outTradeNo
,
duplicateCount
))
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
duplicates
,
nil
}
func
indexIsInvalid
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
indexName
string
)
(
bool
,
error
)
{
var
invalid
bool
err
:=
db
.
QueryRowContext
(
ctx
,
`
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`
,
indexName
)
.
Scan
(
&
invalid
)
return
invalid
,
err
}
func
ensureAtlasBaselineAligned
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
fsys
fs
.
FS
)
error
{
hasLegacy
,
err
:=
tableExists
(
ctx
,
db
,
"schema_migrations"
)
if
err
!=
nil
{
...
...
@@ -328,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return
version
,
version
,
hash
,
nil
}
func
checksumSet
(
values
...
string
)
map
[
string
]
struct
{}
{
out
:=
make
(
map
[
string
]
struct
{},
len
(
values
))
for
_
,
value
:=
range
values
{
out
[
value
]
=
struct
{}{}
}
return
out
}
func
newMigrationChecksumCompatibilityRule
(
fileChecksum
string
,
acceptedDBChecksums
...
string
)
migrationChecksumCompatibilityRule
{
return
migrationChecksumCompatibilityRule
{
fileChecksum
:
fileChecksum
,
acceptedDBChecksum
:
checksumSet
(
acceptedDBChecksums
...
),
acceptedChecksums
:
checksumSet
(
append
([]
string
{
fileChecksum
},
acceptedDBChecksums
...
)
...
),
}
}
func
isMigrationChecksumCompatible
(
name
,
dbChecksum
,
fileChecksum
string
)
bool
{
rule
,
ok
:=
migrationChecksumCompatibilityRules
[
name
]
if
!
ok
{
return
false
}
if
rule
.
fileChecksum
!=
fileChecksum
{
_
,
dbOK
:=
rule
.
acceptedChecksums
[
dbChecksum
]
if
!
dbOK
{
return
false
}
_
,
ok
=
rule
.
accepted
DB
Checksum
[
db
Checksum
]
return
ok
_
,
fileOK
:
=
rule
.
acceptedChecksum
s
[
file
Checksum
]
return
fileOK
}
func
validateMigrationExecutionMode
(
name
,
content
string
)
(
bool
,
error
)
{
...
...
backend/internal/repository/migrations_runner_checksum_test.go
View file @
ddf80f5e
...
...
@@ -55,9 +55,110 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
t
.
Run
(
"109历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"109当前checksum可兼容历史checksum"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"109回滚到历史文件后仍兼容已应用的新checksum"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"109_auth_identity_compat_backfill.sql"
,
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"110历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"110_pending_auth_and_provider_default_grants.sql"
,
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"
,
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"112历史checksum可兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"112_add_payment_order_provider_key_snapshot.sql"
,
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"
,
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"115历史checksum可兼容修复后的legacy external backfill"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"115_auth_identity_legacy_external_backfill.sql"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
,
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"116历史checksum可兼容修复后的legacy external safety reports"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"116_auth_identity_legacy_external_safety_reports.sql"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
,
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"119历史checksum可兼容占位文件"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
,
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"118多个历史checksum都可兼容当前版本"
,
func
(
t
*
testing
.
T
)
{
for
_
,
dbChecksum
:=
range
[]
string
{
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"
,
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"
,
}
{
ok
:=
isMigrationChecksumCompatible
(
"118_wechat_dual_mode_and_auth_source_defaults.sql"
,
dbChecksum
,
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0"
,
)
require
.
True
(
t
,
ok
)
}
})
t
.
Run
(
"120多个历史checksum都可兼容新的notx修复版本"
,
func
(
t
*
testing
.
T
)
{
for
_
,
dbChecksum
:=
range
[]
string
{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
,
}
{
ok
:=
isMigrationChecksumCompatible
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
dbChecksum
,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
)
require
.
True
(
t
,
ok
)
}
})
t
.
Run
(
"119未知checksum不兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
,
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
,
)
require
.
False
(
t
,
ok
)
})
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment